Change tensor uses proto::VarType::type

test=develop
for_weibo
Yu Yang 6 years ago
parent 5e60906996
commit 9bd70a1e04

@ -85,7 +85,7 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->mutable_data(expected_kernel_type.place_, in.type()); out->mutable_data(expected_kernel_type.place_, in.type());
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(in.type()), in.type(),
CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out)); CastDataLayout(pool.Get(expected_kernel_type.place_), axis, in, out));
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.data_layout_);
@ -101,7 +101,7 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
case mkldnn::memory::data_type::f32: case mkldnn::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>()); return platform::to_void_cast(tensor.data<float>());
case mkldnn::memory::data_type::s8: case mkldnn::memory::data_type::s8:
return platform::to_void_cast(tensor.data<char>()); return platform::to_void_cast(tensor.data<int8_t>());
case mkldnn::memory::data_type::u8: case mkldnn::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>()); return platform::to_void_cast(tensor.data<unsigned char>());
case mkldnn::memory::data_type::s16: case mkldnn::memory::data_type::s16:
@ -144,7 +144,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
memory::data_type in_type = ToMKLDNNDataType(in.type()); memory::data_type in_type = ToMKLDNNDataType(in.type());
PADDLE_ENFORCE(in_type != memory::data_type::data_undef, PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
"Input tensor type is not supported: ", in.type().name()); "Input tensor type is not supported: %s", in.type());
memory::data_type out_type = in_type; memory::data_type out_type = in_type;
auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format()); auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());

@ -50,14 +50,14 @@ inline DataLayout ToPaddleLayout(const MKLDNNFormat& format) {
} }
} }
inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static const std::map<std::type_index, MKLDNNDataType> dict{ static std::unordered_map<int, MKLDNNDataType> dict{
{std::type_index(typeid(float)), MKLDNNDataType::f32}, // NOLINT {DataTypeTrait<float>::DataType, MKLDNNDataType::f32},
{std::type_index(typeid(char)), MKLDNNDataType::s8}, // NOLINT {DataTypeTrait<int8_t>::DataType, MKLDNNDataType::s8},
{std::type_index(typeid(unsigned char)), MKLDNNDataType::u8}, {DataTypeTrait<uint8_t>::DataType, MKLDNNDataType::u8},
{std::type_index(typeid(int16_t)), MKLDNNDataType::s16}, {DataTypeTrait<int16_t>::DataType, MKLDNNDataType::s16},
{std::type_index(typeid(int32_t)), MKLDNNDataType::s32}}; {DataTypeTrait<int32_t>::DataType, MKLDNNDataType::s32}};
auto iter = dict.find(type); auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second; if (iter != dict.end()) return iter->second;
return MKLDNNDataType::data_undef; return MKLDNNDataType::data_undef;
} }

@ -26,7 +26,7 @@ struct DataTypeMap {
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_; std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
std::unordered_map<int, std::type_index> proto_to_cpp_; std::unordered_map<int, std::type_index> proto_to_cpp_;
std::unordered_map<int, std::string> proto_to_str_; std::unordered_map<int, std::string> proto_to_str_;
std::unordered_map<std::type_index, size_t> cpp_to_size_; std::unordered_map<int, size_t> proto_to_size_;
}; };
static DataTypeMap* InitDataTypeMap(); static DataTypeMap* InitDataTypeMap();
@ -45,7 +45,7 @@ static inline void RegisterType(DataTypeMap* map,
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T)); map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
map->cpp_to_proto_.emplace(typeid(T), proto_type); map->cpp_to_proto_.emplace(typeid(T), proto_type);
map->proto_to_str_.emplace(static_cast<int>(proto_type), name); map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
map->cpp_to_size_.emplace(typeid(T), sizeof(T)); map->proto_to_size_.emplace(static_cast<int>(proto_type), sizeof(T));
} }
static DataTypeMap* InitDataTypeMap() { static DataTypeMap* InitDataTypeMap() {
@ -54,17 +54,7 @@ static DataTypeMap* InitDataTypeMap() {
#define RegType(cc_type, proto_type) \ #define RegType(cc_type, proto_type) \
RegisterType<cc_type>(retv, proto_type, #cc_type) RegisterType<cc_type>(retv, proto_type, #cc_type)
// NOTE: Add your customize type here. _ForEachDataType_(RegType);
RegType(float16, proto::VarType::FP16);
RegType(float, proto::VarType::FP32);
RegType(double, proto::VarType::FP64);
RegType(int, proto::VarType::INT32);
RegType(int64_t, proto::VarType::INT64);
RegType(bool, proto::VarType::BOOL);
RegType(size_t, proto::VarType::SIZE_T);
RegType(int16_t, proto::VarType::INT16);
RegType(uint8_t, proto::VarType::UINT8);
RegType(int8_t, proto::VarType::INT8);
#undef RegType #undef RegType
return retv; return retv;
@ -96,12 +86,12 @@ std::string DataTypeToString(const proto::VarType::Type type) {
static_cast<int>(type)); static_cast<int>(type));
} }
size_t SizeOfType(std::type_index type) { size_t SizeOfType(proto::VarType::Type type) {
auto it = gDataTypeMap().cpp_to_size_.find(type); auto it = gDataTypeMap().proto_to_size_.find(static_cast<int>(type));
if (it != gDataTypeMap().cpp_to_size_.end()) { if (it != gDataTypeMap().proto_to_size_.end()) {
return it->second; return it->second;
} }
PADDLE_THROW("Not support %s as tensor type", type.name()); PADDLE_THROW("Not support %s as tensor type", DataTypeToString(type));
} }
} // namespace framework } // namespace framework

@ -22,46 +22,59 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct DataTypeTrait {};
// Stub handle for void
template <>
struct DataTypeTrait<void> {
constexpr static auto DataType = proto::VarType::RAW;
};
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define DefineDataTypeTrait(cpp_type, proto_type) \
template <> \
struct DataTypeTrait<cpp_type> { \
constexpr static auto DataType = proto_type; \
}
_ForEachDataType_(DefineDataTypeTrait);
#undef DefineDataTypeTrait
extern proto::VarType::Type ToDataType(std::type_index type); extern proto::VarType::Type ToDataType(std::type_index type);
extern std::type_index ToTypeIndex(proto::VarType::Type type); extern std::type_index ToTypeIndex(proto::VarType::Type type);
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
switch (type) { #define VisitDataTypeCallback(cpp_type, proto_type) \
case proto::VarType::FP16: do { \
visitor.template apply<platform::float16>(); if (type == proto_type) { \
break; visitor.template apply<cpp_type>(); \
case proto::VarType::FP32: return; \
visitor.template apply<float>(); } \
break; } while (0)
case proto::VarType::FP64:
visitor.template apply<double>(); _ForEachDataType_(VisitDataTypeCallback);
break; #undef VisitDataTypeCallback
case proto::VarType::INT32: PADDLE_THROW("Not supported %d", type);
visitor.template apply<int>();
break;
case proto::VarType::INT64:
visitor.template apply<int64_t>();
break;
case proto::VarType::BOOL:
visitor.template apply<bool>();
break;
case proto::VarType::UINT8:
visitor.template apply<uint8_t>();
break;
case proto::VarType::INT16:
visitor.template apply<int16_t>();
break;
case proto::VarType::INT8:
visitor.template apply<int8_t>();
break;
default:
PADDLE_THROW("Not supported %d", type);
}
} }
extern std::string DataTypeToString(const proto::VarType::Type type); extern std::string DataTypeToString(const proto::VarType::Type type);
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
inline std::ostream& operator<<(std::ostream& out, inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) { const proto::VarType::Type& type) {
out << DataTypeToString(type); out << DataTypeToString(type);

@ -26,13 +26,13 @@ TEST(DataType, float16) {
Tensor tensor; Tensor tensor;
CPUPlace cpu; CPUPlace cpu;
tensor.mutable_data(cpu, f::ToTypeIndex(dtype)); tensor.mutable_data(cpu, dtype);
// test fp16 tensor // test fp16 tensor
EXPECT_EQ(tensor.type(), std::type_index(typeid(float16))); EXPECT_EQ(tensor.type(), f::ToDataType(typeid(float16)));
// test fp16 size // test fp16 size
EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u); EXPECT_EQ(f::SizeOfType(dtype), 2u);
// test debug info // test debug info
std::string type = "float16"; std::string type = "float16";

@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
// Reduce All Tensor to trg in CPU // Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg); ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) { for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope = auto &scope =

@ -33,7 +33,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
FuseVarsOpHandle(ir::Node *node, Scope *local_scope, FuseVarsOpHandle(ir::Node *node, Scope *local_scope,
const platform::Place &place, const platform::Place &place,
const std::unordered_map<std::string, int64_t> &inputs_numel, const std::unordered_map<std::string, int64_t> &inputs_numel,
const std::type_index &var_type) const proto::VarType::Type var_type)
: OpHandleBase(node), : OpHandleBase(node),
local_scope_(local_scope), local_scope_(local_scope),
place_(place), place_(place),
@ -57,7 +57,7 @@ struct FuseVarsOpHandle : public OpHandleBase {
Scope *local_scope_; Scope *local_scope_;
const platform::Place place_; const platform::Place place_;
const std::unordered_map<std::string, int64_t> inputs_numel_; const std::unordered_map<std::string, int64_t> inputs_numel_;
const std::type_index type_; const proto::VarType::Type type_;
int64_t total_numel_; int64_t total_numel_;
}; };
} // namespace details } // namespace details

@ -246,7 +246,7 @@ void ReduceOpHandle::RunImpl() {
if (!FLAGS_cpu_deterministic) { if (!FLAGS_cpu_deterministic) {
ReduceLoDTensor func(lod_tensors, ReduceLoDTensor func(lod_tensors,
out_var->GetMutable<framework::LoDTensor>()); out_var->GetMutable<framework::LoDTensor>());
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
} else { } else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
@ -256,7 +256,7 @@ void ReduceOpHandle::RunImpl() {
->FindVar(out_var_handle->name_) ->FindVar(out_var_handle->name_)
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(ToDataType(lod_tensors[0]->type()), func); VisitDataType(lod_tensors[0]->type(), func);
auto trg = out_var->GetMutable<framework::LoDTensor>(); auto trg = out_var->GetMutable<framework::LoDTensor>();
if (reduce_sum_trg.data<void>() != trg->data<void>()) { if (reduce_sum_trg.data<void>() != trg->data<void>()) {

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/dlpack_tensor.h" #include "paddle/fluid/framework/dlpack_tensor.h"
#include "paddle/fluid/framework/data_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -36,26 +36,23 @@ static ::DLDataType GetDLDataTypeCode() {
return dtype; return dtype;
} }
static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) { static std::unordered_map<int, ::DLDataType> CreateDLDataTypeMap() {
#define REG_DL_DATA_TYPE(type) \ static std::unordered_map<int, ::DLDataType> result;
{ std::type_index(typeid(type)), GetDLDataTypeCode<type>() }
static const std::unordered_map<std::type_index, ::DLDataType> #define REG_DL_DATA_TYPE(cpp_type, proto_type) \
type_to_dtype_map({ result[static_cast<int>(proto_type)] = GetDLDataTypeCode<cpp_type>()
REG_DL_DATA_TYPE(platform::float16), // NOLINT
REG_DL_DATA_TYPE(float), // NOLINT _ForEachDataType_(REG_DL_DATA_TYPE);
REG_DL_DATA_TYPE(double), // NOLINT #undef REG_DL_DATA_TYPE
REG_DL_DATA_TYPE(int), // NOLINT return result;
REG_DL_DATA_TYPE(int64_t), // NOLINT }
REG_DL_DATA_TYPE(bool), // NOLINT
REG_DL_DATA_TYPE(size_t), // NOLINT static DLDataType GetDLDataTypeFromTypeIndex(proto::VarType::Type type) {
REG_DL_DATA_TYPE(int16_t), // NOLINT static auto type_to_dtype_map = CreateDLDataTypeMap();
REG_DL_DATA_TYPE(uint8_t), // NOLINT
REG_DL_DATA_TYPE(int8_t) // NOLINT
});
static auto type_to_dtype_map_end_it = type_to_dtype_map.end(); static auto type_to_dtype_map_end_it = type_to_dtype_map.end();
auto it = type_to_dtype_map.find(type); auto it = type_to_dtype_map.find(static_cast<int>(type));
PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %s", PADDLE_ENFORCE(it != type_to_dtype_map_end_it, "Unsupported data type %d",
type.name()); type);
return it->second; return it->second;
#undef REG_DL_DATA_TYPE #undef REG_DL_DATA_TYPE
} }

@ -91,23 +91,11 @@ void TestMainLoop() {
} }
} }
} }
TEST(dlpack, test_all) {
#define TestCallback(cpp_type, proto_type) TestMainLoop<cpp_type>()
#define PADDLE_DLPACK_TEST(type) \ _ForEachDataType_(TestCallback);
TEST(dlpack, test_##type) { TestMainLoop<type>(); } }
using float16 = platform::float16;
PADDLE_DLPACK_TEST(float16);
PADDLE_DLPACK_TEST(float);
PADDLE_DLPACK_TEST(double);
PADDLE_DLPACK_TEST(int);
PADDLE_DLPACK_TEST(int64_t);
PADDLE_DLPACK_TEST(bool);
PADDLE_DLPACK_TEST(size_t);
PADDLE_DLPACK_TEST(int16_t);
PADDLE_DLPACK_TEST(uint8_t);
PADDLE_DLPACK_TEST(int8_t);
#undef PADDLE_DLPACK_TEST
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -138,39 +138,19 @@ void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) {
std::cout << sstream.str() << std::endl; std::cout << sstream.str() << std::endl;
} }
void print_fetch_var(Scope* scope, std::string var_name) { static void print_fetch_var(Scope* scope, const std::string& var_name) {
const LoDTensor& tensor = scope->FindVar(var_name)->Get<LoDTensor>(); auto& tensor = scope->FindVar(var_name)->Get<LoDTensor>();
if (std::type_index(tensor.type()) == #define PrintLoDTensorCallback(cpp_type, proto_type) \
std::type_index(typeid(platform::float16))) { do { \
print_lod_tensor<platform::float16>(var_name, tensor); if (tensor.type() == proto_type) { \
} else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) { print_lod_tensor<cpp_type>(var_name, tensor); \
print_lod_tensor<float>(var_name, tensor); return; \
} else if (std::type_index(tensor.type()) == } \
std::type_index(typeid(double))) { } while (0)
print_lod_tensor<double>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) { _ForEachDataType_(PrintLoDTensorCallback);
print_lod_tensor<int>(var_name, tensor); VLOG(1) << "print_fetch_var: unrecognized data type:" << tensor.type();
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int64_t))) {
print_lod_tensor<int64_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) {
print_lod_tensor<bool>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(uint8_t))) {
print_lod_tensor<uint8_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int16_t))) {
print_lod_tensor<int16_t>(var_name, tensor);
} else if (std::type_index(tensor.type()) ==
std::type_index(typeid(int8_t))) {
print_lod_tensor<int8_t>(var_name, tensor);
} else {
VLOG(1) << "print_fetch_var: unrecognized data type:"
<< tensor.type().name();
}
return;
} }
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {

@ -70,9 +70,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
// only print first ten elements // only print first ten elements
int64_t size = t.numel() < 10 ? t.numel() : 10; int64_t size = t.numel() < 10 ? t.numel() : 10;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
if (IsType<float>(t.type())) { if (t.type() == proto::VarType::FP32) {
os << t.data<float>()[i] << " "; os << t.data<float>()[i] << " ";
} else if (IsType<int64_t>(t.type())) { } else if (t.type() == proto::VarType::INT64) {
os << t.data<int64_t>()[i] << " "; os << t.data<int64_t>()[i] << " ";
} else { } else {
PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
@ -387,7 +387,7 @@ void LoDTensor::MergeLoDTensor(
PADDLE_ENFORCE(!lod_tensors.empty()); PADDLE_ENFORCE(!lod_tensors.empty());
framework::DDim new_dim = lod_tensors[0]->dims(); framework::DDim new_dim = lod_tensors[0]->dims();
std::type_index new_type = lod_tensors[0]->type(); auto new_type = lod_tensors[0]->type();
framework::DataLayout new_layout = lod_tensors[0]->layout(); framework::DataLayout new_layout = lod_tensors[0]->layout();
LoD new_lod = lod_tensors[0]->lod(); LoD new_lod = lod_tensors[0]->lod();
for (size_t i = 1; i < lod_tensors.size(); ++i) { for (size_t i = 1; i < lod_tensors.size(); ++i) {

@ -43,10 +43,9 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
proto::VarType::Type GetDataTypeOfVar(const Variable* var) { proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) { if (var->IsType<framework::LoDTensor>()) {
return framework::ToDataType(var->Get<framework::LoDTensor>().type()); return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) { } else if (var->IsType<framework::SelectedRows>()) {
return framework::ToDataType( return var->Get<framework::SelectedRows>().value().type();
var->Get<framework::SelectedRows>().value().type());
} else { } else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows"); PADDLE_THROW("Var should be LoDTensor or SelectedRows");
} }
@ -93,13 +92,13 @@ static std::string GetDtype(const Scope& scope, const std::string& name) {
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return ""; return "";
} }
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
auto tensor = var->Get<SelectedRows>().value(); auto tensor = var->Get<SelectedRows>().value();
if (UNLIKELY(!tensor.IsInitialized())) { if (UNLIKELY(!tensor.IsInitialized())) {
return "uninited"; return "uninited";
} else { } else {
return DataTypeToString(ToDataType(tensor.type())); return DataTypeToString(tensor.type());
} }
} else { } else {
return ""; return "";
@ -686,7 +685,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
} }
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) { if (tensor.type() != proto::VarType::FP32 &&
tensor.type() != proto::VarType::FP64) {
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
@ -879,7 +879,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value()); t = &(var->Get<SelectedRows>().value());
} }
if (t != nullptr) { if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type())); int tmp = static_cast<int>(t->type());
PADDLE_ENFORCE( PADDLE_ENFORCE(
tmp == data_type || data_type == -1, tmp == data_type || data_type == -1,
"DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)", "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",

@ -218,11 +218,11 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
if (index < 0) { if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0"; VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorFillVisitor(value, i * value_width, value_width, 0.0)); TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else { } else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), value_->type(),
TensorCopyVisitor(value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }

@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
extern size_t SizeOfType(std::type_index type); extern size_t SizeOfType(proto::VarType::Type type);
void Tensor::check_memory_size() const { void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first."); holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
@ -31,7 +31,7 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type, void* Tensor::mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr, memory::Allocator::Attr attr,
size_t requested_size) { size_t requested_size) {
type_ = type; type_ = type;

@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <paddle/fluid/framework/framework.pb.h>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <memory> #include <memory>
@ -67,7 +68,7 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : type_(typeid(float)), offset_(0) {} Tensor() : type_(proto::VarType::FP32), offset_(0) {}
/*! Return a pointer to mutable memory block. */ /*! Return a pointer to mutable memory block. */
template <typename T> template <typename T>
@ -88,7 +89,7 @@ class Tensor {
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type, void* mutable_data(platform::Place place, proto::VarType::Type type,
memory::Allocator::Attr attr = memory::Allocator::kDefault, memory::Allocator::Attr attr = memory::Allocator::kDefault,
size_t requested_size = 0); size_t requested_size = 0);
@ -138,7 +139,7 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { proto::VarType::Type type() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor not initialized yet when Tensor::type() is called."); holder_, "Tensor not initialized yet when Tensor::type() is called.");
return type_; return type_;
@ -161,7 +162,7 @@ class Tensor {
private: private:
/*! holds the memory block if allocated. */ /*! holds the memory block if allocated. */
std::shared_ptr<memory::Allocation> holder_; std::shared_ptr<memory::Allocation> holder_;
std::type_index type_; proto::VarType::Type type_;
/** /**
* @brief points to elements dimensions. * @brief points to elements dimensions.
* *

@ -24,9 +24,8 @@ template <typename T>
inline const T* Tensor::data() const { inline const T* Tensor::data() const {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %d", type_);
type_.name());
return reinterpret_cast<const T*>( return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_); reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
@ -38,9 +37,8 @@ template <typename T>
inline T* Tensor::data() { inline T* Tensor::data() {
check_memory_size(); check_memory_size();
bool valid = bool valid =
std::is_same<T, void>::value || type_ == std::type_index(typeid(T)); std::is_same<T, void>::value || type_ == DataTypeTrait<T>::DataType;
PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", PADDLE_ENFORCE(valid, "Tensor holds the wrong type, it holds %s", type_);
type_.name());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_); offset_);
} }
@ -60,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place,
size_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>( return reinterpret_cast<T*>(
mutable_data(place, typeid(T), attr, requested_size)); mutable_data(place, DataTypeTrait<T>::DataType, attr, requested_size));
} }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {

@ -186,8 +186,8 @@ struct AnyDTypeVisitor {
template <typename Predicate, typename DevCtx> template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) { const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>( VisitDataType(tensor.type(), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out)); predicate, tensor, ctx, out));
} }
template <typename Predicate> template <typename Predicate>
@ -379,7 +379,7 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
// int32_t size // int32_t size
// void* protobuf message // void* protobuf message
proto::VarType::TensorDesc desc; proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type())); desc.set_data_type(tensor.type());
auto dims = framework::vectorize(tensor.dims()); auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims(); auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0); pb_dims->Resize(static_cast<int>(dims.size()), 0);
@ -461,9 +461,7 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
void* buf; void* buf;
auto ctx = platform::CPUDeviceContext(); auto ctx = platform::CPUDeviceContext();
size_t size = size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
tensor->numel() *
framework::SizeOfType(framework::ToTypeIndex(desc.data_type()));
if (platform::is_gpu_place(dev_ctx.GetPlace())) { if (platform::is_gpu_place(dev_ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
Tensor cpu_tensor; Tensor cpu_tensor;

@ -289,10 +289,10 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::proto::VarType::INT64) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {

@ -266,10 +266,10 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
auto type = fetch.type(); auto type = fetch.type();
auto output = &(outputs->at(i)); auto output = &(outputs->at(i));
output->name = fetchs_[idx]->Input("X")[0]; output->name = fetchs_[idx]->Input("X")[0];
if (type == typeid(float)) { if (type == framework::DataTypeTrait<float>::DataType) {
GetFetchOne<float>(fetch, output); GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32; output->dtype = PaddleDType::FLOAT32;
} else if (type == typeid(int64_t)) { } else if (type == framework::DataTypeTrait<int64_t>::DataType) {
GetFetchOne<int64_t>(fetch, output); GetFetchOne<int64_t>(fetch, output);
output->dtype = PaddleDType::INT64; output->dtype = PaddleDType::INT64;
} else { } else {

@ -36,10 +36,10 @@ namespace paddle {
PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) { PaddleTensor LodTensorToPaddleTensor(framework::LoDTensor* t) {
PaddleTensor pt; PaddleTensor pt;
if (t->type() == typeid(int64_t)) { if (t->type() == framework::proto::VarType::INT64) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(int64_t));
pt.dtype = PaddleDType::INT64; pt.dtype = PaddleDType::INT64;
} else if (t->type() == typeid(float)) { } else if (t->type() == framework::proto::VarType::INT32) {
pt.data.Reset(t->data<void>(), t->numel() * sizeof(float)); pt.data.Reset(t->data<void>(), t->numel() * sizeof(float));
pt.dtype = PaddleDType::FLOAT32; pt.dtype = PaddleDType::FLOAT32;
} else { } else {

@ -78,7 +78,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
auto data_type = framework::ToDataType(ctx.Input<Tensor>("Theta")->type()); auto data_type = ctx.Input<Tensor>("Theta")->type();
return framework::OpKernelType(data_type, ctx.GetPlace(), return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library); framework::DataLayout::kAnyLayout, library);
} }
@ -188,9 +188,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
framework::ToDataType(ctx.Input<Tensor>("Theta")->type()), ctx.GetPlace(),
ctx.GetPlace(), framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };

@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);

@ -25,7 +25,5 @@ REGISTER_OP_CUDA_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
size_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext, paddle::operators::ArgMaxKernel<paddle::platform::CUDADeviceContext,
uint8_t>); uint8_t>);

@ -28,6 +28,5 @@ REGISTER_OP_CPU_KERNEL(
int32_t>, int32_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
int16_t>, int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, size_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext, paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>); uint8_t>);

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save