|
|
|
@ -137,6 +137,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
template <typename Predicate, typename DevCtx>
|
|
|
|
|
struct AnyDTypeVisitor {
|
|
|
|
|
Predicate predicate_;
|
|
|
|
@ -149,7 +150,7 @@ struct AnyDTypeVisitor {
|
|
|
|
|
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() const {
|
|
|
|
|
void apply()() const {
|
|
|
|
|
auto t = EigenVector<T>::Flatten(tensor_);
|
|
|
|
|
auto o = EigenScalar<bool>::From(*out_);
|
|
|
|
|
// return any of predicate_(t) is true.
|
|
|
|
@ -173,7 +174,7 @@ struct AnyVisitor : public boost::static_visitor<bool> {
|
|
|
|
|
: tensor_(tensor), predicate_(std::move(predicate)) {}
|
|
|
|
|
|
|
|
|
|
template <typename Place>
|
|
|
|
|
bool operator()(const Place& place) const {
|
|
|
|
|
bool apply()(const Place& place) const {
|
|
|
|
|
framework::Tensor out;
|
|
|
|
|
out.Resize({1});
|
|
|
|
|
out.mutable_data<bool>(place);
|
|
|
|
@ -240,6 +241,7 @@ bool TensorContainsInf(const framework::Tensor& tensor) {
|
|
|
|
|
ContainsInfPredicate predicate;
|
|
|
|
|
return Any(tensor, predicate);
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
void TensorToStream(std::ostream& os, const Tensor& tensor,
|
|
|
|
|
const platform::DeviceContext& dev_ctx) {
|
|
|
|
@ -302,7 +304,7 @@ struct DeserializedDataFunctor {
|
|
|
|
|
: buf_(buf), tensor_(tensor), place_(place) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()() {
|
|
|
|
|
void apply() {
|
|
|
|
|
*buf_ = tensor_->mutable_data<T>(place_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|