|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include <stdint.h>
|
|
|
|
|
#include <mutex> // NOLINT
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
|
@ -28,20 +27,27 @@ struct DataTypeMap {
|
|
|
|
|
std::unordered_map<std::type_index, size_t> cpp_to_size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static DataTypeMap g_data_type_map_;
|
|
|
|
|
static DataTypeMap* InitDataTypeMap();
|
|
|
|
|
static DataTypeMap& gDataTypeMap() {
|
|
|
|
|
static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
|
|
|
|
|
return *g_data_type_map_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static inline void RegisterType(proto::VarType::Type proto_type,
|
|
|
|
|
const std::string &name) {
|
|
|
|
|
g_data_type_map_.proto_to_cpp_.emplace(static_cast<int>(proto_type),
|
|
|
|
|
typeid(T));
|
|
|
|
|
g_data_type_map_.cpp_to_proto_.emplace(typeid(T), proto_type);
|
|
|
|
|
g_data_type_map_.proto_to_str_.emplace(static_cast<int>(proto_type), name);
|
|
|
|
|
g_data_type_map_.cpp_to_size_.emplace(typeid(T), sizeof(T));
|
|
|
|
|
static inline void RegisterType(DataTypeMap* map,
|
|
|
|
|
proto::VarType::Type proto_type,
|
|
|
|
|
const std::string& name) {
|
|
|
|
|
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
|
|
|
|
|
map->cpp_to_proto_.emplace(typeid(T), proto_type);
|
|
|
|
|
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
|
|
|
|
|
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int RegisterAllTypes() {
|
|
|
|
|
#define RegType(cc_type, proto_type) RegisterType<cc_type>(proto_type, #cc_type)
|
|
|
|
|
static DataTypeMap* InitDataTypeMap() {
|
|
|
|
|
auto retv = new DataTypeMap();
|
|
|
|
|
|
|
|
|
|
#define RegType(cc_type, proto_type) \
|
|
|
|
|
RegisterType<cc_type>(retv, proto_type, #cc_type)
|
|
|
|
|
|
|
|
|
|
// NOTE: Add your customize type here.
|
|
|
|
|
RegType(platform::float16, proto::VarType::FP16);
|
|
|
|
@ -52,24 +58,20 @@ static int RegisterAllTypes() {
|
|
|
|
|
RegType(bool, proto::VarType::BOOL);
|
|
|
|
|
|
|
|
|
|
#undef RegType
|
|
|
|
|
return 0;
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::once_flag register_once_flag_;
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type ToDataType(std::type_index type) {
|
|
|
|
|
std::call_once(register_once_flag_, RegisterAllTypes);
|
|
|
|
|
auto it = g_data_type_map_.cpp_to_proto_.find(type);
|
|
|
|
|
if (it != g_data_type_map_.cpp_to_proto_.end()) {
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_proto_.find(type);
|
|
|
|
|
if (it != gDataTypeMap().cpp_to_proto_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support %s as tensor type", type.name());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::type_index ToTypeIndex(proto::VarType::Type type) {
|
|
|
|
|
std::call_once(register_once_flag_, RegisterAllTypes);
|
|
|
|
|
auto it = g_data_type_map_.proto_to_cpp_.find(static_cast<int>(type));
|
|
|
|
|
if (it != g_data_type_map_.proto_to_cpp_.end()) {
|
|
|
|
|
auto it = gDataTypeMap().proto_to_cpp_.find(static_cast<int>(type));
|
|
|
|
|
if (it != gDataTypeMap().proto_to_cpp_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
|
|
|
|
@ -77,9 +79,8 @@ std::type_index ToTypeIndex(proto::VarType::Type type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string DataTypeToString(const proto::VarType::Type type) {
|
|
|
|
|
std::call_once(register_once_flag_, RegisterAllTypes);
|
|
|
|
|
auto it = g_data_type_map_.proto_to_str_.find(static_cast<int>(type));
|
|
|
|
|
if (it != g_data_type_map_.proto_to_str_.end()) {
|
|
|
|
|
auto it = gDataTypeMap().proto_to_str_.find(static_cast<int>(type));
|
|
|
|
|
if (it != gDataTypeMap().proto_to_str_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
|
|
|
|
@ -87,9 +88,8 @@ std::string DataTypeToString(const proto::VarType::Type type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t SizeOfType(std::type_index type) {
|
|
|
|
|
std::call_once(register_once_flag_, RegisterAllTypes);
|
|
|
|
|
auto it = g_data_type_map_.cpp_to_size_.find(type);
|
|
|
|
|
if (it != g_data_type_map_.cpp_to_size_.end()) {
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_size_.find(type);
|
|
|
|
|
if (it != gDataTypeMap().cpp_to_size_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support %s as tensor type", type.name());
|
|
|
|
|