diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 5467f6d1b2..72c50518af 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, out->mutable_data(expected_kernel_type.place_, in.type()); framework::VisitDataType( - framework::ToDataType(in.type()), + in.type(), CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out)); out->set_layout(expected_kernel_type.data_layout_); @@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { case mkldnn::memory::data_type::f32: return platform::to_void_cast(tensor.data()); case mkldnn::memory::data_type::s8: - return platform::to_void_cast(tensor.data()); + return platform::to_void_cast(tensor.data()); case mkldnn::memory::data_type::u8: return platform::to_void_cast(tensor.data()); case mkldnn::memory::data_type::s16: @@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, memory::data_type in_type = ToMKLDNNDataType(in.type()); PADDLE_ENFORCE(in_type != memory::data_type::data_undef, - "Input tensor type is not supported: ", in.type().name()); + "Input tensor type is not supported: %s", in.type()); memory::data_type out_type = in_type; auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 90bb206ec6..2479de4fd4 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) { } } -inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { - static const std::map dict{ - {std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT - {std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT - {std::type_index(typeid(unsigned char)), MKLDNNDataType::u8}, - {std::type_index(typeid(int16_t)), MKLDNNDataType::s16}, - {std::type_index(typeid(int32_t)), MKLDNNDataType::s32}}; - auto iter = dict.find(type); +inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) { + static std::unordered_map dict{ + {DataTypeTrait::DataType, MKLDNNDataType::f32}, + {DataTypeTrait::DataType, MKLDNNDataType::s8}, + {DataTypeTrait::DataType, MKLDNNDataType::u8}, + {DataTypeTrait::DataType, MKLDNNDataType::s16}, + {DataTypeTrait::DataType, MKLDNNDataType::s32}}; + auto iter = dict.find(static_cast(type)); if (iter != dict.end()) return iter->second; return MKLDNNDataType::data_undef; } diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index 28f3da88fa..a0248cf3c7 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -26,7 +26,7 @@ struct DataTypeMap { std::unordered_map cpp_to_proto_; std::unordered_map proto_to_cpp_; std::unordered_map proto_to_str_; - std::unordered_map cpp_to_size_; + std::unordered_map proto_to_size_; }; static DataTypeMap* InitDataTypeMap(); @@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map, map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T)); map->cpp_to_proto_.emplace(typeid(T), proto_type); map->proto_to_str_.emplace(static_cast(proto_type), name); - map->cpp_to_size_.emplace(typeid(T), sizeof(T)); + map->proto_to_size_.emplace(static_cast(proto_type), sizeof(T)); } static DataTypeMap* InitDataTypeMap() { @@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() { #define RegType(cc_type, proto_type) \ RegisterType(retv, proto_type, #cc_type) - // NOTE: Add your customize type here. - RegType(float16, proto::VarType::FP16); - RegType(float, proto::VarType::FP32); - RegType(double, proto::VarType::FP64); - RegType(int, proto::VarType::INT32); - RegType(int64_t, proto::VarType::INT64); - RegType(bool, proto::VarType::BOOL); - RegType(size_t, proto::VarType::SIZE_T); - RegType(int16_t, proto::VarType::INT16); - RegType(uint8_t, proto::VarType::UINT8); - RegType(int8_t, proto::VarType::INT8); + _ForEachDataType_(RegType); #undef RegType return retv; @@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) { static_cast(type)); } -size_t SizeOfType(std::type_index type) { - auto it = gDataTypeMap().cpp_to_size_.find(type); - if (it != gDataTypeMap().cpp_to_size_.end()) { +size_t SizeOfType(proto::VarType::Type type) { + auto it = gDataTypeMap().proto_to_size_.find(static_cast(type)); + if (it != gDataTypeMap().proto_to_size_.end()) { return it->second; } - PADDLE_THROW("Not support %s as tensor type", type.name()); + PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type)); } } // namespace framework diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index d5be43b33e..76df78ea5e 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -22,46 +22,59 @@ limitations under the License. */ namespace paddle { namespace framework { +template +struct DataTypeTrait {}; + +// Stub handle for void +template <> +struct DataTypeTrait { + constexpr static auto DataType = proto::VarType::RAW; +}; + +#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \ + callback(cpp_type, ::paddle::framework::proto::VarType::proto_type); + +#define _ForEachDataType_(callback) \ + _ForEachDataTypeHelper_(callback, float, FP32); \ + _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \ + _ForEachDataTypeHelper_(callback, double, FP64); \ + _ForEachDataTypeHelper_(callback, int, INT32); \ + _ForEachDataTypeHelper_(callback, int64_t, INT64); \ + _ForEachDataTypeHelper_(callback, bool, BOOL); \ + _ForEachDataTypeHelper_(callback, uint8_t, UINT8); \ + _ForEachDataTypeHelper_(callback, int16_t, INT16); \ + _ForEachDataTypeHelper_(callback, int8_t, INT8) + +#define DefineDataTypeTrait(cpp_type, proto_type) \ + template <> \ + struct DataTypeTrait { \ + constexpr static auto DataType = proto_type; \ + } + +_ForEachDataType_(DefineDataTypeTrait); + +#undef DefineDataTypeTrait + extern proto::VarType::Type ToDataType(std::type_index type); extern std::type_index ToTypeIndex(proto::VarType::Type type); template inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { - switch (type) { - case proto::VarType::FP16: - visitor.template apply(); - break; - case proto::VarType::FP32: - visitor.template apply(); - break; - case proto::VarType::FP64: - visitor.template apply(); - break; - case proto::VarType::INT32: - visitor.template apply(); - break; - case proto::VarType::INT64: - visitor.template apply(); - break; - case proto::VarType::BOOL: - visitor.template apply(); - break; - case proto::VarType::UINT8: - visitor.template apply(); - break; - case proto::VarType::INT16: - visitor.template apply(); - break; - case proto::VarType::INT8: - visitor.template apply(); - break; - default: - PADDLE_THROW("Not supported %d", type); - } +#define VisitDataTypeCallback(cpp_type, proto_type) \ + do { \ + if (type == proto_type) { \ + visitor.template apply(); \ + return; \ + } \ + } while (0) + + _ForEachDataType_(VisitDataTypeCallback); +#undef VisitDataTypeCallback + PADDLE_THROW("Not supported %d", type); } extern std::string DataTypeToString(const proto::VarType::Type type); -extern size_t SizeOfType(std::type_index type); +extern size_t SizeOfType(proto::VarType::Type type); inline std::ostream& operator<<(std::ostream& out, const proto::VarType::Type& type) { out << DataTypeToString(type); diff --git a/paddle/fluid/framework/data_type_test.cc b/paddle/fluid/framework/data_type_test.cc index 54c41c55ba..92639dfc61 100644 --- a/paddle/fluid/framework/data_type_test.cc +++ b/paddle/fluid/framework/data_type_test.cc @@ -26,13 +26,13 @@ TEST(DataType, float16) { Tensor tensor; CPUPlace cpu; - tensor.mutable_data(cpu, f::ToTypeIndex(dtype)); + tensor.mutable_data(cpu, dtype); // test fp16 tensor - EXPECT_EQ(tensor.type(), std::type_index(typeid(float16))); + EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16))); // test fp16 size - EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u); + EXPECT_EQ(f::SizeOfType(dtype), 2u); // test debug info std::string type = "float16"; diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index e8bf53e160..9eaff1f560 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() { // Reduce All Tensor to trg in CPU ReduceLoDTensor func(lod_tensors, &trg); - VisitDataType(ToDataType(lod_tensors[0]->type()), func); + VisitDataType(lod_tensors[0]->type(), func); for (size_t i = 1; i < local_scopes_.size(); ++i) { auto &scope = diff --git a/paddle/fluid/framework/details/fuse_vars_op_handle.h b/paddle/fluid/framework/details/fuse_vars_op_handle.h index 3f360c510a..b40b01df36 100644 --- a/paddle/fluid/framework/details/fuse_vars_op_handle.h +++ b/paddle/fluid/framework/details/fuse_vars_op_handle.h @@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase { FuseVarsOpHandle(ir::Node *node, Scope *local_scope, const platform::Place &place, const std::unordered_map &inputs_numel, - const std::type_index &var_type) + const proto::VarType::Type var_type) : OpHandleBase(node), local_scope_(local_scope), place_(place), @@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase { Scope *local_scope_; const platform::Place place_; const std::unordered_map inputs_numel_; - const std::type_index type_; + const proto::VarType::Type type_; int64_t total_numel_; }; } // namespace details diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index cb864848b9..85d8abc910 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() { if (!FLAGS_cpu_deterministic) { ReduceLoDTensor func(lod_tensors, out_var->GetMutable()); - VisitDataType(ToDataType(lod_tensors[0]->type()), func); + VisitDataType(lod_tensors[0]->type(), func); } else { // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. @@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() { ->FindVar(out_var_handle->name_) ->GetMutable(); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); - VisitDataType(ToDataType(lod_tensors[0]->type()), func); + VisitDataType(lod_tensors[0]->type(), func); auto trg = out_var->GetMutable(); if (reduce_sum_trg.data() != trg->data()) { diff --git a/paddle/fluid/framework/dlpack_tensor.cc b/paddle/fluid/framework/dlpack_tensor.cc index 04e3f78afe..eaef093ed3 100644 --- a/paddle/fluid/framework/dlpack_tensor.cc +++ b/paddle/fluid/framework/dlpack_tensor.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/dlpack_tensor.h" - +#include "paddle/fluid/framework/data_type.h" namespace paddle { namespace framework { @@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() { return dtype; } -static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) { -#define REG_DL_DATA_TYPE(type) \ - { std::type_index(typeid(type)), GetDLDataTypeCode() } - static const std::unordered_map - type_to_dtype_map({ - REG_DL_DATA_TYPE(platform::float16), // NOLINT - REG_DL_DATA_TYPE(float), // NOLINT - REG_DL_DATA_TYPE(double), // NOLINT - REG_DL_DATA_TYPE(int), // NOLINT - REG_DL_DATA_TYPE(int64_t), // NOLINT - REG_DL_DATA_TYPE(bool), // NOLINT - REG_DL_DATA_TYPE(size_t), // NOLINT - REG_DL_DATA_TYPE(int16_t), // NOLINT - REG_DL_DATA_TYPE(uint8_t), // NOLINT - REG_DL_DATA_TYPE(int8_t) // NOLINT - }); +static std::unordered_map CreateDLDataTypeMap() { + static std::unordered_map result; + +#define REG_DL_DATA_TYPE(cpp_type, proto_type) \ + result[static_cast(proto_type)] = GetDLDataTypeCode() + + _ForEachDataType_(REG_DL_DATA_TYPE); +#undef REG_DL_DATA_TYPE + return result; +} + +static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) { + static auto type_to_dtype_map = CreateDLDataTypeMap(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); - auto it = type_to_dtype_map.find(type); - PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s", - type.name()); + auto it = type_to_dtype_map.find(static_cast(type)); + PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d", + type); return it->second; #undef REG_DL_DATA_TYPE } diff --git a/paddle/fluid/framework/dlpack_tensor_test.cc b/paddle/fluid/framework/dlpack_tensor_test.cc index 938b056350..c0a8e1bcdf 100644 --- a/paddle/fluid/framework/dlpack_tensor_test.cc +++ b/paddle/fluid/framework/dlpack_tensor_test.cc @@ -91,23 +91,11 @@ void TestMainLoop() { } } } +TEST(dlpack, test_all) { +#define TestCallback(cpp_type, proto_type) TestMainLoop() -#define PADDLE_DLPACK_TEST(type) \ - TEST(dlpack, test_##type) { TestMainLoop(); } - -using float16 = platform::float16; -PADDLE_DLPACK_TEST(float16); -PADDLE_DLPACK_TEST(float); -PADDLE_DLPACK_TEST(double); -PADDLE_DLPACK_TEST(int); -PADDLE_DLPACK_TEST(int64_t); -PADDLE_DLPACK_TEST(bool); -PADDLE_DLPACK_TEST(size_t); -PADDLE_DLPACK_TEST(int16_t); -PADDLE_DLPACK_TEST(uint8_t); -PADDLE_DLPACK_TEST(int8_t); - -#undef PADDLE_DLPACK_TEST + _ForEachDataType_(TestCallback); +} } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index 3d53511615..f03f39dfc6 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) { std::cout << sstream.str() << std::endl; } -void print_fetch_var(Scope* scope, std::string var_name) { - const LoDTensor& tensor = scope->FindVar(var_name)->Get(); - - if (std::type_index(tensor.type()) == - std::type_index(typeid(platform::float16))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == - std::type_index(typeid(double))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == - std::type_index(typeid(int64_t))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == - std::type_index(typeid(uint8_t))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == - std::type_index(typeid(int16_t))) { - print_lod_tensor(var_name, tensor); - } else if (std::type_index(tensor.type()) == - std::type_index(typeid(int8_t))) { - print_lod_tensor(var_name, tensor); - } else { - VLOG(1) << "print_fetch_var: unrecognized data type:" - << tensor.type().name(); - } - - return; +static void print_fetch_var(Scope* scope, const std::string& var_name) { + auto& tensor = scope->FindVar(var_name)->Get(); + +#define PrintLoDTensorCallback(cpp_type, proto_type) \ + do { \ + if (tensor.type() == proto_type) { \ + print_lod_tensor(var_name, tensor); \ + return; \ + } \ + } while (0) + + _ForEachDataType_(PrintLoDTensorCallback); + VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type(); } void ExecutorThreadWorker::TrainFiles() { diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 9b2eeaf59a..6c8bec32de 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { // only print first ten elements int64_t size = t.numel() < 10 ? t.numel() : 10; for (int64_t i = 0; i < size; ++i) { - if (IsType(t.type())) { + if (t.type() == proto::VarType::FP32) { os << t.data()[i] << " "; - } else if (IsType(t.type())) { + } else if (t.type() == proto::VarType::INT64) { os << t.data()[i] << " "; } else { PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); @@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor( PADDLE_ENFORCE(!lod_tensors.empty()); framework::DDim new_dim = lod_tensors[0]->dims(); - std::type_index new_type = lod_tensors[0]->type(); + auto new_type = lod_tensors[0]->type(); framework::DataLayout new_layout = lod_tensors[0]->layout(); LoD new_lod = lod_tensors[0]->lod(); for (size_t i = 1; i < lod_tensors.size(); ++i) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c6f3254e9f..05ab48412a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -43,10 +43,9 @@ std::vector> kKernelPriority = { proto::VarType::Type GetDataTypeOfVar(const Variable* var) { if (var->IsType()) { - return framework::ToDataType(var->Get().type()); + return var->Get().type(); } else if (var->IsType()) { - return framework::ToDataType( - var->Get().value().type()); + return var->Get().value().type(); } else { PADDLE_THROW("Var should be LoDTensor or SelectedRows"); } @@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { if (UNLIKELY(!tensor.IsInitialized())) { return ""; } - return DataTypeToString(ToDataType(tensor.type())); + return DataTypeToString(tensor.type()); } else if (var->IsType()) { auto tensor = var->Get().value(); if (UNLIKELY(!tensor.IsInitialized())) { return "uninited"; } else { - return DataTypeToString(ToDataType(tensor.type())); + return DataTypeToString(tensor.type()); } } else { return ""; @@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name, if (tensor.memory_size() == 0) { return; } - if (!IsType(tensor.type()) && !IsType(tensor.type())) { + if (tensor.type() != proto::VarType::FP32 && + tensor.type() != proto::VarType::FP64) { return; } PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), @@ -879,7 +879,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( t = &(var->Get().value()); } if (t != nullptr) { - int tmp = static_cast(ToDataType(t->type())); + int tmp = static_cast(t->type()); PADDLE_ENFORCE( tmp == data_type || data_type == -1, "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 62a30815d4..54a818250b 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, if (index < 0) { VLOG(5) << "id " << id << " not in the table, return 0"; framework::VisitDataType( - framework::ToDataType(value_->type()), + value_->type(), TensorFillVisitor(value, i * value_width, value_width, 0.0)); } else { framework::VisitDataType( - framework::ToDataType(value_->type()), + value_->type(), TensorCopyVisitor(value, i * value_width, *value_.get(), index * value_width, value_width)); } diff --git a/paddle/fluid/framework/tensor.cc b/paddle/fluid/framework/tensor.cc index 41566800e5..57335847a1 100644 --- a/paddle/fluid/framework/tensor.cc +++ b/paddle/fluid/framework/tensor.cc @@ -16,7 +16,7 @@ limitations under the License. */ namespace paddle { namespace framework { -extern size_t SizeOfType(std::type_index type); +extern size_t SizeOfType(proto::VarType::Type type); void Tensor::check_memory_size() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); @@ -31,7 +31,7 @@ size_t Tensor::memory_size() const { return holder_ == nullptr ? 0UL : holder_->size() - offset_; } -void* Tensor::mutable_data(platform::Place place, std::type_index type, +void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type, memory::Allocator::Attr attr, size_t requested_size) { type_ = type; diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 71e8badd4b..057fe1f98c 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -67,7 +68,7 @@ class Tensor { friend struct EigenVector; public: - Tensor() : type_(typeid(float)), offset_(0) {} + Tensor() : type_(proto::VarType::FP32), offset_(0) {} /*! Return a pointer to mutable memory block. */ template @@ -88,7 +89,7 @@ class Tensor { memory::Allocator::Attr attr = memory::Allocator::kDefault, size_t requested_size = 0); - void* mutable_data(platform::Place place, std::type_index type, + void* mutable_data(platform::Place place, proto::VarType::Type type, memory::Allocator::Attr attr = memory::Allocator::kDefault, size_t requested_size = 0); @@ -138,7 +139,7 @@ class Tensor { return holder_->place(); } - std::type_index type() const { + proto::VarType::Type type() const { PADDLE_ENFORCE_NOT_NULL( holder_, "Tensor not initialized yet when Tensor::type() is called."); return type_; @@ -161,7 +162,7 @@ class Tensor { private: /*! holds the memory block if allocated. */ std::shared_ptr holder_; - std::type_index type_; + proto::VarType::Type type_; /** * @brief points to elements dimensions. * diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h index 0c9c0d782f..ce3ad18b1f 100644 --- a/paddle/fluid/framework/tensor_impl.h +++ b/paddle/fluid/framework/tensor_impl.h @@ -24,9 +24,8 @@ template inline const T* Tensor::data() const { check_memory_size(); bool valid = - std::is_same::value || type_ == std::type_index(typeid(T)); - PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", - type_.name()); + std::is_same::value || type_ == DataTypeTrait::DataType; + PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_); return reinterpret_cast( reinterpret_cast(holder_->ptr()) + offset_); @@ -38,9 +37,8 @@ template inline T* Tensor::data() { check_memory_size(); bool valid = - std::is_same::value || type_ == std::type_index(typeid(T)); - PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", - type_.name()); + std::is_same::value || type_ == DataTypeTrait::DataType; + PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_); return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } @@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) { static_assert(std::is_pod::value, "T must be POD"); return reinterpret_cast( - mutable_data(place, typeid(T), attr, requested_size)); + mutable_data(place, DataTypeTrait::DataType, attr, requested_size)); } inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index ca1e01c89f..85d15c5d3f 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -186,8 +186,8 @@ struct AnyDTypeVisitor { template inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, const DevCtx& ctx, framework::Tensor* out) { - VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( - predicate, tensor, ctx, out)); + VisitDataType(tensor.type(), AnyDTypeVisitor( + predicate, tensor, ctx, out)); } template @@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, // int32_t size // void* protobuf message proto::VarType::TensorDesc desc; - desc.set_data_type(framework::ToDataType(tensor.type())); + desc.set_data_type(tensor.type()); auto dims = framework::vectorize(tensor.dims()); auto* pb_dims = desc.mutable_dims(); pb_dims->Resize(static_cast(dims.size()), 0); @@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor, tensor->Resize(framework::make_ddim(dims)); void* buf; auto ctx = platform::CPUDeviceContext(); - size_t size = - tensor->numel() * - framework::SizeOfType(framework::ToTypeIndex(desc.data_type())); + size_t size = tensor->numel() * framework::SizeOfType(desc.data_type()); if (platform::is_gpu_place(dev_ctx.GetPlace())) { #ifdef PADDLE_WITH_CUDA Tensor cpu_tensor; diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index be51e7fc1f..c751e85158 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, auto type = fetch.type(); auto output = &(outputs->at(i)); output->name = fetchs_[idx]->Input("X")[0]; - if (type == typeid(float)) { + if (type == framework::proto::VarType::FP32) { GetFetchOne(fetch, output); output->dtype = PaddleDType::FLOAT32; - } else if (type == typeid(int64_t)) { + } else if (type == framework::proto::VarType::INT64) { GetFetchOne(fetch, output); output->dtype = PaddleDType::INT64; } else { diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 4c5b412a2c..3d121e0460 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector *outputs, auto type = fetch.type(); auto output = &(outputs->at(i)); output->name = fetchs_[idx]->Input("X")[0]; - if (type == typeid(float)) { + if (type == framework::DataTypeTrait::DataType) { GetFetchOne(fetch, output); output->dtype = PaddleDType::FLOAT32; - } else if (type == typeid(int64_t)) { + } else if (type == framework::DataTypeTrait::DataType) { GetFetchOne(fetch, output); output->dtype = PaddleDType::INT64; } else { diff --git a/paddle/fluid/inference/api/api_impl_tester.cc b/paddle/fluid/inference/api/api_impl_tester.cc index 014bdc6a37..191225493c 100644 --- a/paddle/fluid/inference/api/api_impl_tester.cc +++ b/paddle/fluid/inference/api/api_impl_tester.cc @@ -36,10 +36,10 @@ namespace paddle { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor pt; - if (t->type() == typeid(int64_t)) { + if (t->type() == framework::proto::VarType::INT64) { pt.data.Reset(t->data(), t->numel() * sizeof(int64_t)); pt.dtype = PaddleDType::INT64; - } else if (t->type() == typeid(float)) { + } else if (t->type() == framework::proto::VarType::INT32) { pt.data.Reset(t->data(), t->numel() * sizeof(float)); pt.dtype = PaddleDType::FLOAT32; } else { diff --git a/paddle/fluid/operators/affine_grid_op.cc b/paddle/fluid/operators/affine_grid_op.cc index 6f7da445fc..1de59a5165 100644 --- a/paddle/fluid/operators/affine_grid_op.cc +++ b/paddle/fluid/operators/affine_grid_op.cc @@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel { library = framework::LibraryType::kCUDNN; } #endif - auto data_type = framework::ToDataType(ctx.Input("Theta")->type()); + auto data_type = ctx.Input("Theta")->type(); return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kAnyLayout, library); } @@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Theta")->type()), - ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType(ctx.Input("Theta")->type(), + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/arg_max_op.cc b/paddle/fluid/operators/arg_max_op.cc index 8174d37358..7fe9a0df74 100644 --- a/paddle/fluid/operators/arg_max_op.cc +++ b/paddle/fluid/operators/arg_max_op.cc @@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( int32_t>, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_max_op.cu b/paddle/fluid/operators/arg_max_op.cu index a147d77a9e..85e4f98173 100644 --- a/paddle/fluid/operators/arg_max_op.cu +++ b/paddle/fluid/operators/arg_max_op.cu @@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( int32_t>, paddle::operators::ArgMaxKernel, - paddle::operators::ArgMaxKernel, paddle::operators::ArgMaxKernel); diff --git a/paddle/fluid/operators/arg_min_op.cc b/paddle/fluid/operators/arg_min_op.cc index 41f188029f..23b24735cd 100644 --- a/paddle/fluid/operators/arg_min_op.cc +++ b/paddle/fluid/operators/arg_min_op.cc @@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL( int32_t>, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/arg_min_op.cu b/paddle/fluid/operators/arg_min_op.cu index 4d02050850..47d7c8b122 100644 --- a/paddle/fluid/operators/arg_min_op.cu +++ b/paddle/fluid/operators/arg_min_op.cu @@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL( int32_t>, paddle::operators::ArgMinKernel, - paddle::operators::ArgMinKernel, paddle::operators::ArgMinKernel); diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc index 6257e04b01..d942391b86 100644 --- a/paddle/fluid/operators/array_to_lod_tensor_op.cc +++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc @@ -58,7 +58,7 @@ struct ArrayToLoDFunctor : public boost::static_visitor { ArrayToLoDFunctorImpl functor; functor.dev_ctx_ = dev_ctx; functor.prev_functor_ = this; - framework::VisitDataType(framework::ToDataType(out->type()), functor); + framework::VisitDataType(out->type(), functor); } }; @@ -91,7 +91,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { PADDLE_ENFORCE(!x.empty(), "There's no element in the input array."); int rank = x[0].dims().size(); platform::Place place = x[0].place(); - std::type_index data_type = x[0].type(); + auto data_type = x[0].type(); int64_t batch_size = x[0].dims()[0]; framework::DDim ins_dims = rank > 1 ? framework::slice_ddim(x[0].dims(), 1, rank) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 75fc59125f..b6996be4b0 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -121,9 +121,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } void AttentionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc index f389eab605..0922b03b5f 100644 --- a/paddle/fluid/operators/average_accumulates_op.cc +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -103,9 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("param")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index f66813989c..8b672e09b2 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -72,8 +72,7 @@ class BatchNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("X")->type()); + auto input_data_type = ctx.Input("X")->type(); // By default, the type of the scale, bias, mean, // and var tensors should both be float. (For float or float16 input tensor) // or double (For double input tensor). @@ -81,17 +80,13 @@ class BatchNormOp : public framework::OperatorWithKernel { if (input_data_type == framework::proto::VarType::FP64) { bn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ(bn_param_type, - framework::ToDataType(ctx.Input("Scale")->type()), + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), "Scale input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, - framework::ToDataType(ctx.Input("Bias")->type()), + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), "Bias input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, - framework::ToDataType(ctx.Input("Mean")->type()), + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), "Mean input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( - ctx.Input("Variance")->type()), + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), "Variance input should be of float type"); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready @@ -413,9 +408,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel { } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace(), layout, library); } }; diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 0d32cae0e1..ae9765b761 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -145,7 +145,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { LoDTensor* sentenceScores = ctx.Output("SentenceScores"); framework::VisitDataType( - framework::ToDataType(scores->at(0).type()), + scores->at(0).type(), BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores, beam_size, end_id)); } diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 62771d09f1..30f700f1d9 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -282,8 +282,7 @@ class BeamSearchOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::OpKernelType kt = framework::OpKernelType( - framework::ToDataType( - ctx.Input("pre_ids")->type()), + ctx.Input("pre_ids")->type(), platform::CPUPlace()); return kt; } diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc index 9258d7c7e8..f349c51d8a 100644 --- a/paddle/fluid/operators/bpr_loss_op.cc +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -47,9 +47,8 @@ class BprLossOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; @@ -94,9 +93,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 135254ce6b..dd28f82b65 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -48,13 +48,12 @@ class ConditionalOp : public framework::OperatorBase { if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { PADDLE_THROW("should have one initialized input as condition"); } - if (!(framework::IsType(ips[0]->type()) && // NOLINT - ips[0]->numel() == 1)) { - PADDLE_THROW( - "condition input's data type should be bool, " - "numel should be 1, actual numel is %d", - ips[0]->numel()); - } + + PADDLE_ENFORCE(ips[0]->type() == framework::proto::VarType::BOOL && + ips[0]->numel() == 1, + "condition input's data type should be bool, " + "numel should be 1, actual numel is %d", + ips[0]->numel()); bool res = false; if (platform::is_gpu_place(ips[0]->place())) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 6c1b2f329a..66f8508f02 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -237,7 +237,7 @@ class WhileGradOp : public framework::OperatorBase { if (var->IsType()) { auto &inside_tensor = var->Get(); framework::AttributeMap attrs; - attrs["dtype"] = framework::ToDataType(inside_tensor.type()); + attrs["dtype"] = inside_tensor.type(); attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["value"] = 0.0f; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index d7b8766288..183850db18 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -95,10 +95,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } #endif - auto input_data_type = - framework::ToDataType(ctx.Input("Input")->type()); - auto filter_data_type = - framework::ToDataType(ctx.Input("Filter")->type()); + auto input_data_type = ctx.Input("Input")->type(); + auto filter_data_type = ctx.Input("Filter")->type(); PADDLE_ENFORCE_EQ(input_data_type, filter_data_type, "input and filter data type should be consistent"); @@ -382,9 +380,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_, customized_type_value); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_, + customized_type_value); } } // namespace operators diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 2fdfc40d19..86a140f152 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -104,9 +104,8 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_); } void Conv2DTransposeOpMaker::Make() { @@ -335,9 +334,8 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_); } } // namespace operators diff --git a/paddle/fluid/operators/crf_decoding_op.cc b/paddle/fluid/operators/crf_decoding_op.cc index c27befe114..81c9e9e543 100644 --- a/paddle/fluid/operators/crf_decoding_op.cc +++ b/paddle/fluid/operators/crf_decoding_op.cc @@ -118,9 +118,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Emission")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("Emission")->type(), + platform::CPUPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/crop_op.cc b/paddle/fluid/operators/crop_op.cc index a2a871efa8..97d20681b8 100644 --- a/paddle/fluid/operators/crop_op.cc +++ b/paddle/fluid/operators/crop_op.cc @@ -51,9 +51,8 @@ class CropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -174,9 +173,7 @@ class CropOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input(framework::GradVarName("Out")) - ->type()), + ctx.Input(framework::GradVarName("Out"))->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index a904dd9130..1968e54b00 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -57,9 +57,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -111,9 +110,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { // is determined by its input "X". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc index d2b440d9d2..e7c472f8c0 100644 --- a/paddle/fluid/operators/ctc_align_op.cc +++ b/paddle/fluid/operators/ctc_align_op.cc @@ -36,9 +36,8 @@ class CTCAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cc b/paddle/fluid/operators/detection/anchor_generator_op.cc index 0c0155a0a9..f2984d1af2 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cc +++ b/paddle/fluid/operators/detection/anchor_generator_op.cc @@ -53,8 +53,7 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/bipartite_match_op.cc b/paddle/fluid/operators/detection/bipartite_match_op.cc index c23b65fe4d..b7da1261a8 100644 --- a/paddle/fluid/operators/detection/bipartite_match_op.cc +++ b/paddle/fluid/operators/detection/bipartite_match_op.cc @@ -45,9 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("DistMat")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("DistMat")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/density_prior_box_op.cc b/paddle/fluid/operators/detection/density_prior_box_op.cc index 1012ba3652..cacd47ed4a 100644 --- a/paddle/fluid/operators/detection/density_prior_box_op.cc +++ b/paddle/fluid/operators/detection/density_prior_box_op.cc @@ -66,8 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.GetPlace()); + ctx.Input("Input")->type(), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc index 709c2dfc4b..2c46803fd0 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cc +++ b/paddle/fluid/operators/detection/generate_proposals_op.cc @@ -66,9 +66,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Anchors")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Anchors")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/mine_hard_examples_op.cc b/paddle/fluid/operators/detection/mine_hard_examples_op.cc index 54a4b87ec8..f70e6adb5b 100644 --- a/paddle/fluid/operators/detection/mine_hard_examples_op.cc +++ b/paddle/fluid/operators/detection/mine_hard_examples_op.cc @@ -249,8 +249,7 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("ClsLoss")->type()), - platform::CPUPlace()); + ctx.Input("ClsLoss")->type(), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index f0f8851be0..2395b18148 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -65,8 +65,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input("Scores")->type()), + ctx.Input("Scores")->type(), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index b5cb6a724c..3e75c0394f 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -72,8 +72,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc index 42c720e701..3796854fe6 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cc +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cc @@ -498,9 +498,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -519,9 +518,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 46fff9d338..dc6c3d5a66 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -78,8 +78,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input("Anchor")->type()), + ctx.Input("Anchor")->type(), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/detection/target_assign_op.cc b/paddle/fluid/operators/detection/target_assign_op.cc index 3670019392..c057c82ce0 100644 --- a/paddle/fluid/operators/detection/target_assign_op.cc +++ b/paddle/fluid/operators/detection/target_assign_op.cc @@ -57,9 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index d7f49a9590..e1d113f854 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -71,8 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input("DetectRes")->type()), + ctx.Input("DetectRes")->type(), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 87bf7c6b15..41644d8cc1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -197,8 +197,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = framework::ToDataType( - ctx.Input(framework::GradVarName("Out"))->type()); + auto input_data_type = + ctx.Input(framework::GradVarName("Out"))->type(); #ifdef PADDLE_WITH_MKLDNN if (platform::CanMKLDNNBeUsed(ctx)) { diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 43af83fd69..8aff911141 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -115,9 +115,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -175,9 +174,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index e80249fc87..1ed8a2ddd1 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -79,9 +79,8 @@ framework::OpKernelType FCOp::GetExpectedKernelType( library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout, library); } void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { @@ -111,9 +110,8 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout, library); } void FCOpMaker::Make() { diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 252f313440..38cb33e790 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -59,9 +59,9 @@ class FillConstantOp : public framework::OperatorBase { if (force_cpu) { auto cpu = platform::CPUPlace(); - tensor->mutable_data(cpu, framework::ToTypeIndex(data_type)); + tensor->mutable_data(cpu, data_type); } else { - tensor->mutable_data(dev_place, framework::ToTypeIndex(data_type)); + tensor->mutable_data(dev_place, data_type); } platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); diff --git a/paddle/fluid/operators/fill_op.cc b/paddle/fluid/operators/fill_op.cc index adc7cb1f9e..a885b301e7 100644 --- a/paddle/fluid/operators/fill_op.cc +++ b/paddle/fluid/operators/fill_op.cc @@ -55,7 +55,7 @@ class FillOp : public framework::OperatorBase { static_cast(Attr("dtype")); platform::CPUPlace cpu; auto force_cpu = Attr("force_cpu"); - out.mutable_data(force_cpu ? cpu : place, framework::ToTypeIndex(dtype)); + out.mutable_data(force_cpu ? cpu : place, dtype); framework::LoDTensor tensor; @@ -64,7 +64,7 @@ class FillOp : public framework::OperatorBase { } else { // Always make tensor in CPU memory. tensor.Resize(out.dims()); - tensor.mutable_data(cpu, framework::ToTypeIndex(dtype)); + tensor.mutable_data(cpu, dtype); } framework::VisitDataType( diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc index 3771aac0df..0fbf564b7e 100644 --- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.cc @@ -135,9 +135,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), ctx.Input("Y")->type(), "The element's type of input should be the same."); - auto input_data_type = - framework::ToDataType(ctx.Input("X")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; @@ -324,9 +323,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type_index = ctx.Input("Y")->type(); - auto input_data_type = framework::ToDataType(input_data_type_index); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Y")->type(), + ctx.GetPlace()); } }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc index 1eb6523a2d..f1466f17fe 100644 --- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc +++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc @@ -115,8 +115,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( - framework::ToDataType( - ctx.Input("Embeddings")->type()), + ctx.Input("Embeddings")->type(), ctx.device_context()); } diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 25b7ae7c28..4ce67e16dd 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -93,9 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 8021a896ce..c4e752e3f0 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -117,9 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc index 40bba09f3e..b05329cfd0 100644 --- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc @@ -61,9 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } void FusionSeqConvEltAddReluOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc index 17ed9771d0..aaef46de0d 100644 --- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc @@ -67,9 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - return framework::OpKernelType( - framework::ToDataType(ctx.MultiInput("X")[0]->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + ctx.device_context()); } void FusionSeqExpandConcatFCOpMaker::Make() { diff --git a/paddle/fluid/operators/gather_op.cc b/paddle/fluid/operators/gather_op.cc index 95aa9b573c..0a8c0814a7 100644 --- a/paddle/fluid/operators/gather_op.cc +++ b/paddle/fluid/operators/gather_op.cc @@ -42,9 +42,8 @@ class GatherOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -60,9 +59,8 @@ class GatherGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index e76eb6893b..14a2524bd8 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -63,9 +63,9 @@ class GridSampleOp : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; @@ -159,9 +159,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { library_ = framework::LibraryType::kCUDNN; } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - framework::DataLayout::kAnyLayout, library_); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace(), + framework::DataLayout::kAnyLayout, library_); } }; diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 6322659b67..4fa15058f8 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -141,8 +141,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel { if (t == nullptr) { PADDLE_THROW("can't find Y@GRAD"); } - return framework::OpKernelType(framework::ToDataType(t->type()), - ctx.GetPlace()); + return framework::OpKernelType(t->type(), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc index 0dbcc442df..a807117115 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc @@ -76,9 +76,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; @@ -163,9 +162,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 4d25822259..93dd3f794f 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -55,8 +55,8 @@ class InterpolateOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; @@ -124,8 +124,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/is_empty_op.cc b/paddle/fluid/operators/is_empty_op.cc index 29b73951bb..ba50bdf34b 100644 --- a/paddle/fluid/operators/is_empty_op.cc +++ b/paddle/fluid/operators/is_empty_op.cc @@ -35,8 +35,7 @@ class IsEmptyOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::OpKernelType kt = framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + ctx.Input("X")->type(), platform::CPUPlace()); return kt; } }; diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index 7b42efd623..1312eecfa4 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -40,10 +40,9 @@ class OverflowOp : public framework::OperatorWithKernel { int dtype = -1; auto *x_var = ctx.InputVar("X"); if (x_var->IsType()) { - dtype = framework::ToDataType(x_var->Get().type()); + dtype = x_var->Get().type(); } else if (x_var->IsType()) { - dtype = framework::ToDataType( - x_var->Get().value().type()); + dtype = x_var->Get().value().type(); } else { PADDLE_THROW("Cannot find the input data type by all input data"); } diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 14ce1da2e9..f83fe355b8 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -153,8 +153,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel { if (t == nullptr) { PADDLE_THROW("can't find Y@GRAD"); } - return framework::OpKernelType(framework::ToDataType(t->type()), - ctx.GetPlace()); + return framework::OpKernelType(t->type(), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc index ea1ca7f59d..998b7f09c3 100644 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -184,9 +184,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { // is determined by its input "Emission". framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Emission")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("Emission")->type(), + platform::CPUPlace()); } }; @@ -244,9 +243,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input(framework::GradVarName("LogLikelihood")) - ->type()), + ctx.Input(framework::GradVarName("LogLikelihood"))->type(), platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc index 9d1423915a..e28d199eeb 100644 --- a/paddle/fluid/operators/load_combine_op.cc +++ b/paddle/fluid/operators/load_combine_op.cc @@ -69,7 +69,7 @@ class LoadCombineOp : public framework::OperatorBase { // Get data from fin to tensor DeserializeFromStream(*buffer, tensor, dev_ctx); - auto in_dtype = framework::ToDataType(tensor->type()); + auto in_dtype = tensor->type(); auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index df1edc5c2e..06773d1d0e 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -65,7 +65,7 @@ class LoadOp : public framework::OperatorBase { DeserializeFromStream(fin, tensor, dev_ctx); auto load_as_fp16 = Attr("load_as_fp16"); - auto in_dtype = framework::ToDataType(tensor->type()); + auto in_dtype = tensor->type(); auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; if (in_dtype != out_dtype) { diff --git a/paddle/fluid/operators/lod_reset_op.cc b/paddle/fluid/operators/lod_reset_op.cc index 0d4e84e850..7c8fe5fbd7 100644 --- a/paddle/fluid/operators/lod_reset_op.cc +++ b/paddle/fluid/operators/lod_reset_op.cc @@ -39,9 +39,8 @@ class LoDResetOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -144,9 +143,8 @@ class LoDResetGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc index 145d2db118..9b91cf5260 100644 --- a/paddle/fluid/operators/lod_tensor_to_array_op.cc +++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc @@ -72,7 +72,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor { LoDTensorToArrayFunctorImpl func; func.prev_functor_ = this; func.dev_ctx_ = dev_ctx; - framework::VisitDataType(framework::ToDataType(input_.type()), func); + framework::VisitDataType(input_.type(), func); } }; diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index 1b55527fd3..4840a7ac1e 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -63,8 +63,7 @@ class LookupSparseTableOp : public framework::OperatorBase { out_shape[0] = ids_t.numel(); out_t->Resize(out_shape); out_t->mutable_data(cpu, w_t->value().type()); - PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()), - framework::proto::VarType::FP32, + PADDLE_ENFORCE_EQ(w_t->value().type(), framework::proto::VarType::FP32, "The sparse table only support FP32"); w_t->Get(ids_t, out_t, true, is_test); out_t->set_lod(ids_t.lod()); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index a3bb2be5c7..06ac31b5f1 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -145,9 +145,8 @@ framework::OpKernelType GetExpectedLRNKernel( } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), + layout_, library_); } } // namespace diff --git a/paddle/fluid/operators/lstm_op.cc b/paddle/fluid/operators/lstm_op.cc index 3225bf9bb6..4a199d681f 100644 --- a/paddle/fluid/operators/lstm_op.cc +++ b/paddle/fluid/operators/lstm_op.cc @@ -96,8 +96,7 @@ class LSTMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; @@ -261,8 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/lstmp_op.cc b/paddle/fluid/operators/lstmp_op.cc index e398b51480..7a62bc9f82 100644 --- a/paddle/fluid/operators/lstmp_op.cc +++ b/paddle/fluid/operators/lstmp_op.cc @@ -113,8 +113,7 @@ class LSTMPOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; @@ -312,8 +311,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.device_context()); + ctx.Input("Input")->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index 854c8653ff..e1491a8156 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -77,16 +77,14 @@ template <> void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { - framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value)); } template <> void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { - framework::VisitDataType(framework::ToDataType(tensor->type()), - TensorSetConstantCPU(tensor, value)); + framework::VisitDataType(tensor->type(), TensorSetConstantCPU(tensor, value)); } struct TensorSetConstantWithPlace : public boost::static_visitor { diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 9372d63f0b..4645b3ae6e 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -65,7 +65,7 @@ template <> void set_constant_with_place( const platform::DeviceContext& context, framework::Tensor* tensor, float value) { - framework::VisitDataType(framework::ToDataType(tensor->type()), + framework::VisitDataType(tensor->type(), TensorSetConstantGPU(context, tensor, value)); } diff --git a/paddle/fluid/operators/mean_iou_op.cc b/paddle/fluid/operators/mean_iou_op.cc index a60f245f53..bb290046f3 100644 --- a/paddle/fluid/operators/mean_iou_op.cc +++ b/paddle/fluid/operators/mean_iou_op.cc @@ -44,9 +44,8 @@ class MeanIoUOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Predictions")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Predictions")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mean_op.cc b/paddle/fluid/operators/mean_op.cc index 820636defa..35b6d7b5e3 100644 --- a/paddle/fluid/operators/mean_op.cc +++ b/paddle/fluid/operators/mean_op.cc @@ -61,9 +61,7 @@ class MeanGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("X")->type()); - + auto input_data_type = ctx.Input("X")->type(); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/merge_lod_tensor_op.cc b/paddle/fluid/operators/merge_lod_tensor_op.cc index 2dc1467b0d..da7fa1b81d 100644 --- a/paddle/fluid/operators/merge_lod_tensor_op.cc +++ b/paddle/fluid/operators/merge_lod_tensor_op.cc @@ -63,9 +63,7 @@ class MergeLoDTensorOp : public framework::OperatorBase { platform::Place place = dev_place; int64_t batch_size = in_true.dims()[0] + in_false.dims()[0]; - - std::type_index data_type = - in_true.IsInitialized() ? in_true.type() : in_false.type(); + auto data_type = in_true.IsInitialized() ? in_true.type() : in_false.type(); int rank; framework::DDim in_dims; if (in_true.IsInitialized()) { diff --git a/paddle/fluid/operators/metrics/accuracy_op.cc b/paddle/fluid/operators/metrics/accuracy_op.cc index 95aa76bc69..7db6dff297 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cc +++ b/paddle/fluid/operators/metrics/accuracy_op.cc @@ -55,9 +55,8 @@ class AccuracyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Out")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Out")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/metrics/auc_op.cc b/paddle/fluid/operators/metrics/auc_op.cc index 335d4fded4..5e33dd9606 100644 --- a/paddle/fluid/operators/metrics/auc_op.cc +++ b/paddle/fluid/operators/metrics/auc_op.cc @@ -51,9 +51,8 @@ class AucOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Predict")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("Predict")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/metrics/precision_recall_op.cc b/paddle/fluid/operators/metrics/precision_recall_op.cc index 0d733c47dd..1a67b13491 100644 --- a/paddle/fluid/operators/metrics/precision_recall_op.cc +++ b/paddle/fluid/operators/metrics/precision_recall_op.cc @@ -82,9 +82,8 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("MaxProbs")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("MaxProbs")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc index 18ad46cb5e..1801f2915e 100644 --- a/paddle/fluid/operators/multiplex_op.cc +++ b/paddle/fluid/operators/multiplex_op.cc @@ -53,9 +53,8 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.MultiInput("X")[0]->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + ctx.device_context()); } }; @@ -123,9 +122,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.MultiInput("X")[0]->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.MultiInput("X")[0]->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index 9f97f7821d..06c35c789f 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -69,9 +69,8 @@ class NCEOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("Input")->type(), + platform::CPUPlace()); } }; @@ -214,9 +213,8 @@ class NCEOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("Input")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc index 9039d02b67..dd365629fc 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cc +++ b/paddle/fluid/operators/optimizers/adadelta_op.cc @@ -70,9 +70,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cc b/paddle/fluid/operators/optimizers/adagrad_op.cc index e8d5a9e2c8..bd1bb98e63 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/adagrad_op.cc @@ -59,9 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 5710cda39a..5eae503461 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -75,8 +75,7 @@ class AdamOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); + auto input_data_type = ctx.Input("Param")->type(); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/adamax_op.cc b/paddle/fluid/operators/optimizers/adamax_op.cc index 4b244a76dc..aef1fc972c 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cc +++ b/paddle/fluid/operators/optimizers/adamax_op.cc @@ -76,9 +76,8 @@ class AdamaxOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc index 80278441c0..07899278f9 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cc @@ -64,9 +64,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cc b/paddle/fluid/operators/optimizers/ftrl_op.cc index 1c9e91d9b6..c1a4f5790b 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cc +++ b/paddle/fluid/operators/optimizers/ftrl_op.cc @@ -66,8 +66,7 @@ class FTRLOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); + auto input_data_type = ctx.Input("Param")->type(); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc index 7b07b3b707..9dd9b8afbd 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cc @@ -58,9 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cc b/paddle/fluid/operators/optimizers/proximal_gd_op.cc index dcef4f7be2..fccfc2b458 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cc +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cc @@ -46,9 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type = - framework::ToDataType(ctx.Input("Param")->type()); - return framework::OpKernelType(input_data_type, ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Param")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc index a9da21f479..6ef2dacb38 100644 --- a/paddle/fluid/operators/pad2d_op.cc +++ b/paddle/fluid/operators/pad2d_op.cc @@ -511,8 +511,8 @@ class Pad2dOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; @@ -612,8 +612,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pad_constant_like_op.cc b/paddle/fluid/operators/pad_constant_like_op.cc index 685ebc3937..3f827c26fd 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cc +++ b/paddle/fluid/operators/pad_constant_like_op.cc @@ -47,9 +47,8 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Y")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Y")->type(), + ctx.device_context()); } }; @@ -171,9 +170,8 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Y")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Y")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 52b607df74..6259954849 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -99,9 +99,8 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( } #endif - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), + layout_, library_); } void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { @@ -130,7 +129,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( } #endif - auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); + auto input_data_type = ctx.Input("X")->type(); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, "float16 can only be used when CUDNN is used"); diff --git a/paddle/fluid/operators/pool_with_index_op.cc b/paddle/fluid/operators/pool_with_index_op.cc index 873706593e..179ee96e01 100644 --- a/paddle/fluid/operators/pool_with_index_op.cc +++ b/paddle/fluid/operators/pool_with_index_op.cc @@ -71,9 +71,8 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -92,9 +91,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/positive_negative_pair_op.cc b/paddle/fluid/operators/positive_negative_pair_op.cc index 4d865b7f17..99256e408d 100644 --- a/paddle/fluid/operators/positive_negative_pair_op.cc +++ b/paddle/fluid/operators/positive_negative_pair_op.cc @@ -87,9 +87,8 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Score")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Score")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 64d94ab604..62c55c4f55 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,9 +56,8 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -113,9 +112,8 @@ class PReluGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index e7f1caf4d3..6a5bf17060 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -172,7 +172,7 @@ class TensorPrintOp : public framework::OperatorBase { formater.name = printed_var_name; } if (Attr("print_tensor_type")) { - formater.dtype = printed_tensor.type(); + formater.dtype = framework::ToTypeIndex(printed_tensor.type()); } if (Attr("print_tensor_shape")) { auto &dims = printed_tensor.dims(); diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index 123fa44fa3..cd3bd32adb 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -22,9 +22,8 @@ class RandomCropOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index e17c2ffd39..f771cebd0c 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -99,10 +99,10 @@ void BatchReader::ReadNextImpl(std::vector* out) { out->reserve(out_num); for (size_t j = 0; j < out_num; ++j) { // Merge shape and check date type - std::type_index batch_type = buffer_[0][j].type(); + auto batch_type = buffer_[0][j].type(); framework::DDim batch_shape = buffer_[0][j].dims(); for (size_t i = 1; i < buffer_.size(); ++i) { - std::type_index ins_type = buffer_[i][j].type(); + auto ins_type = buffer_[i][j].type(); framework::DDim ins_shape = buffer_[i][j].dims(); PADDLE_ENFORCE_EQ(batch_type, ins_type); PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()), diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index 162bfcbb08..a1e02a3fd0 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -414,7 +414,7 @@ class RecurrentGradOp : public RecurrentBase { auto &inside_tensor = cur_scope.FindVar(inside_grad_name) ->Get(); framework::AttributeMap attrs; - attrs["dtype"] = framework::ToDataType(inside_tensor.type()); + attrs["dtype"] = inside_tensor.type(); attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["value"] = 0.0f; diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 500d86fec3..289d848ea1 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -108,9 +108,8 @@ class ReshapeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -189,9 +188,8 @@ class ReshapeGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -322,9 +320,7 @@ class Reshape2GradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input(framework::GradVarName("Out")) - ->type()), + ctx.Input(framework::GradVarName("Out"))->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/rnn_memory_helper_op.cc b/paddle/fluid/operators/rnn_memory_helper_op.cc index 0fb7776fd9..834dd1eabd 100644 --- a/paddle/fluid/operators/rnn_memory_helper_op.cc +++ b/paddle/fluid/operators/rnn_memory_helper_op.cc @@ -99,7 +99,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase { auto &in_var_tensor = in_var->Get(); framework::AttributeMap attrs; - attrs["dtype"] = framework::ToDataType(in_var_tensor.type()); + attrs["dtype"] = in_var_tensor.type(); attrs["shape"] = framework::vectorize2int(in_var_tensor.dims()); attrs["value"] = 0.0f; diff --git a/paddle/fluid/operators/roi_align_op.cc b/paddle/fluid/operators/roi_align_op.cc index 79f189222e..6857b5ed9d 100644 --- a/paddle/fluid/operators/roi_align_op.cc +++ b/paddle/fluid/operators/roi_align_op.cc @@ -62,9 +62,8 @@ class ROIAlignOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -83,9 +82,8 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/roi_pool_op.cc b/paddle/fluid/operators/roi_pool_op.cc index 3f6b2e46c7..e46d92d6fc 100644 --- a/paddle/fluid/operators/roi_pool_op.cc +++ b/paddle/fluid/operators/roi_pool_op.cc @@ -69,9 +69,8 @@ class ROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -90,9 +89,8 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 5b05f757c0..a0b9fa305d 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -75,7 +75,7 @@ class SaveCombineOp : public framework::OperatorBase { // Serialize tensors one by one // Check types to see if a fp16 transformation is required - auto in_dtype = framework::ToDataType(tensor.type()); + auto in_dtype = tensor.type(); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc index e79cffcf49..e1c9fd8ff1 100644 --- a/paddle/fluid/operators/save_op.cc +++ b/paddle/fluid/operators/save_op.cc @@ -85,7 +85,7 @@ class SaveOp : public framework::OperatorBase { filename); auto save_as_fp16 = Attr("save_as_fp16"); - auto in_dtype = framework::ToDataType(tensor.type()); + auto in_dtype = tensor.type(); auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; if (in_dtype != out_dtype) { diff --git a/paddle/fluid/operators/scatter_op.cc b/paddle/fluid/operators/scatter_op.cc index c32d2603cf..ad418d51bc 100644 --- a/paddle/fluid/operators/scatter_op.cc +++ b/paddle/fluid/operators/scatter_op.cc @@ -51,9 +51,8 @@ class ScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -70,9 +69,8 @@ class ScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 44b09bf7c2..1754221e77 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -114,9 +114,8 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc index c49d1ccb18..8267c04f9f 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_scatter_op.cc @@ -112,9 +112,8 @@ class SequenceScatterOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; @@ -131,9 +130,8 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc index 6f84023e26..35f49f78ce 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_slice_op.cc @@ -50,9 +50,8 @@ class SequenceSliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -71,9 +70,8 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc index 644a5bebc1..027073e5d7 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_softmax_op.cc @@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + ctx.Input("X")->type(), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; @@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { } std::string data_format = ctx.Attr("data_format"); return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + ctx.Input("X")->type(), ctx.GetPlace(), framework::StringToDataLayout(data_format), library_); } }; diff --git a/paddle/fluid/operators/similarity_focus_op.cc b/paddle/fluid/operators/similarity_focus_op.cc index 9612f82b6d..21871d7656 100644 --- a/paddle/fluid/operators/similarity_focus_op.cc +++ b/paddle/fluid/operators/similarity_focus_op.cc @@ -70,9 +70,8 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index e55462d6cf..789e61b2d3 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -59,9 +59,8 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Input")->type()), - ctx.GetPlace()); + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 091ce4e6e8..bc889a5a04 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -62,8 +62,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { } #endif - auto input_data_type = - framework::ToDataType(ctx.Input("X")->type()); + auto input_data_type = ctx.Input("X")->type(); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); @@ -169,8 +168,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = framework::ToDataType( - ctx.Input(framework::GradVarName("Out"))->type()); + auto input_data_type = + ctx.Input(framework::GradVarName("Out"))->type(); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "float16 can only be used on GPU place"); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 2900221485..0397c7791e 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -131,9 +131,8 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Logits")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Logits")->type(), + ctx.device_context()); } }; @@ -173,8 +172,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input(framework::GradVarName("Loss"))->type()), + ctx.Input(framework::GradVarName("Loss"))->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 7df14158f3..4f717a4355 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -91,9 +91,9 @@ class SumOp : public framework::OperatorWithKernel { continue; } if (dtype == -1) { - dtype = framework::ToDataType(tensor->type()); + dtype = tensor->type(); } else { - PADDLE_ENFORCE_EQ(dtype, framework::ToDataType(tensor->type())); + PADDLE_ENFORCE_EQ(dtype, tensor->type()); } } PADDLE_ENFORCE_NE(dtype, -1, @@ -106,8 +106,8 @@ class SumOp : public framework::OperatorWithKernel { for (auto& var : x_vars) { auto& value = var->Get().value(); if (value.IsInitialized()) { - return framework::OpKernelType(framework::ToDataType(value.type()), - ctx.device_context(), layout, library); + return framework::OpKernelType(value.type(), ctx.device_context(), + layout, library); } } // if input sparse vars are not initialized, use an default kernel type. @@ -118,9 +118,8 @@ class SumOp : public framework::OperatorWithKernel { auto& array = x_var->Get(); for (auto& each : array) { if (each.numel() != 0) { - return framework::OpKernelType(framework::ToDataType(each.type()), - ctx.device_context(), layout, - library); + return framework::OpKernelType(each.type(), ctx.device_context(), + layout, library); } } } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 6eef4c98c4..5b2aad55a4 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -76,10 +76,7 @@ class TensorRTEngineOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { auto input0 = ctx.Inputs("Xs").front(); framework::OpKernelType kt = framework::OpKernelType( - framework::ToDataType(ctx.scope() - .FindVar(input0) - ->GetMutable() - ->type()), + ctx.scope().FindVar(input0)->GetMutable()->type(), ctx.GetPlace()); return kt; } diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index bbd71db606..bc1f59bc1a 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -144,9 +144,8 @@ class Transpose2Op : public TransposeOp { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -194,9 +193,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.Input(framework::GradVarName("Out")) - ->type()), + ctx.Input(framework::GradVarName("Out"))->type(), ctx.device_context()); } }; diff --git a/paddle/fluid/operators/unpool_op.cc b/paddle/fluid/operators/unpool_op.cc index 6d2ccb38f6..11e505d6df 100644 --- a/paddle/fluid/operators/unpool_op.cc +++ b/paddle/fluid/operators/unpool_op.cc @@ -74,9 +74,8 @@ class UnpoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } public: @@ -113,9 +112,8 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } public: diff --git a/paddle/fluid/operators/warpctc_op.cc b/paddle/fluid/operators/warpctc_op.cc index 6a257cebf5..e2ae7caae1 100644 --- a/paddle/fluid/operators/warpctc_op.cc +++ b/paddle/fluid/operators/warpctc_op.cc @@ -56,9 +56,8 @@ class WarpCTCOp : public framework::OperatorWithKernel { } #endif framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Logits")->type()), - ctx.device_context(), layout_, library_); + return framework::OpKernelType(ctx.Input("Logits")->type(), + ctx.device_context(), layout_, library_); } }; @@ -136,9 +135,8 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("Logits")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("Logits")->type(), + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/yolov3_loss_op.cc b/paddle/fluid/operators/yolov3_loss_op.cc index e7597f7324..60508f7ab8 100644 --- a/paddle/fluid/operators/yolov3_loss_op.cc +++ b/paddle/fluid/operators/yolov3_loss_op.cc @@ -64,9 +64,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; @@ -180,9 +179,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - platform::CPUPlace()); + return framework::OpKernelType(ctx.Input("X")->type(), + platform::CPUPlace()); } }; diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 7c539d25f6..cbb090adef 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -20,6 +20,7 @@ #include // NOLINT #include #include +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/enforce.h" @@ -28,14 +29,14 @@ namespace paddle { namespace platform { -inline ncclDataType_t ToNCCLDataType(std::type_index type) { - if (type == typeid(float)) { // NOLINT +inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { + if (type == framework::proto::VarType::FP32) { return ncclFloat; - } else if (type == typeid(double)) { // NOLINT + } else if (type == framework::proto::VarType::FP64) { return ncclDouble; - } else if (type == typeid(int)) { // NOLINT + } else if (type == framework::proto::VarType::INT32) { return ncclInt; - } else if (type == typeid(int64_t)) { // NOLINT + } else if (type == framework::proto::VarType::INT64) { return ncclInt64; } else { PADDLE_THROW("Not supported"); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dca0c01ab2..314ab98625 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -206,7 +206,7 @@ PYBIND11_MODULE(core, m) { .def("_get_float_element", TensorGetElement) .def("_set_double_element", TensorSetElement) .def("_get_double_element", TensorGetElement) - .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); }); + .def("_dtype", [](Tensor &self) { return self.type(); }); py::class_(m, "LoDTensor", R"DOC( LoDTensor is a Tensor with optional LoD information. diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index f67f40f19f..5e91f5b301 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -43,7 +43,7 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; pybind11::buffer_info operator()(const framework::Tensor &tensor) { - if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) { + if (framework::DataTypeTrait::DataType == tensor.type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; std::vector strides;