|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include <stdint.h>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
|
@ -24,10 +23,10 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
struct DataTypeMap {
|
|
|
|
|
std::map<const char*, proto::VarType::Type> cpp_to_proto_;
|
|
|
|
|
std::unordered_map<std::type_index, proto::VarType::Type> cpp_to_proto_;
|
|
|
|
|
std::unordered_map<int, std::type_index> proto_to_cpp_;
|
|
|
|
|
std::unordered_map<int, std::string> proto_to_str_;
|
|
|
|
|
std::map<const char* /*name pointer*/, size_t> cpp_to_size_;
|
|
|
|
|
std::unordered_map<std::type_index, size_t> cpp_to_size_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static DataTypeMap* InitDataTypeMap();
|
|
|
|
@ -44,9 +43,9 @@ static inline void RegisterType(DataTypeMap* map,
|
|
|
|
|
proto::VarType::Type proto_type,
|
|
|
|
|
const std::string& name) {
|
|
|
|
|
map->proto_to_cpp_.emplace(static_cast<int>(proto_type), typeid(T));
|
|
|
|
|
map->cpp_to_proto_.emplace(typeid(T).name(), proto_type);
|
|
|
|
|
map->cpp_to_proto_.emplace(typeid(T), proto_type);
|
|
|
|
|
map->proto_to_str_.emplace(static_cast<int>(proto_type), name);
|
|
|
|
|
map->cpp_to_size_.emplace(typeid(T).name(), sizeof(T));
|
|
|
|
|
map->cpp_to_size_.emplace(typeid(T), sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static DataTypeMap* InitDataTypeMap() {
|
|
|
|
@ -72,7 +71,7 @@ static DataTypeMap* InitDataTypeMap() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type ToDataType(std::type_index type) {
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_proto_.find(type.name());
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_proto_.find(type);
|
|
|
|
|
if (it != gDataTypeMap().cpp_to_proto_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
@ -98,8 +97,8 @@ std::string DataTypeToString(const proto::VarType::Type type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t SizeOfType(std::type_index type) {
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_size_.find(type.name());
|
|
|
|
|
if (LIKELY(it != gDataTypeMap().cpp_to_size_.end())) {
|
|
|
|
|
auto it = gDataTypeMap().cpp_to_size_.find(type);
|
|
|
|
|
if (it != gDataTypeMap().cpp_to_size_.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_THROW("Not support %s as tensor type", type.name());
|
|
|
|
|