|
|
|
@ -26,75 +26,40 @@ namespace framework {
|
|
|
|
|
extern proto::VarType::Type ToDataType(std::type_index type);
|
|
|
|
|
extern std::type_index ToTypeIndex(proto::VarType::Type type);
|
|
|
|
|
|
|
|
|
|
#if !defined(_WIN32)
|
|
|
|
|
template <typename Visitor>
|
|
|
|
|
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
visitor.template operator()<platform::float16>();
|
|
|
|
|
visitor.template apply<platform::float16>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
visitor.template operator()<float>();
|
|
|
|
|
visitor.template apply<float>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
visitor.template operator()<double>();
|
|
|
|
|
visitor.template apply<double>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
visitor.template operator()<int>();
|
|
|
|
|
visitor.template apply<int>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
visitor.template operator()<int64_t>();
|
|
|
|
|
visitor.template apply<int64_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
visitor.template operator()<bool>();
|
|
|
|
|
visitor.template apply<bool>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::UINT8:
|
|
|
|
|
visitor.template operator()<uint8_t>();
|
|
|
|
|
visitor.template apply<uint8_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT16:
|
|
|
|
|
visitor.template operator()<int16_t>();
|
|
|
|
|
visitor.template apply<int16_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT8:
|
|
|
|
|
visitor.template operator()<int8_t>();
|
|
|
|
|
visitor.template apply<int8_t>();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not supported %d", type);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
// the msvc compiler do not implement two-stage name lookup correctly.
|
|
|
|
|
template <typename Visitor>
|
|
|
|
|
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
visitor.operator()<platform::float16>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
visitor.operator()<float>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
visitor.operator()<double>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
visitor.operator()<int>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
visitor.operator()<int64_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
visitor.operator()<bool>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::UINT8:
|
|
|
|
|
visitor.operator()<uint8_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT16:
|
|
|
|
|
visitor.operator()<int16_t>();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not supported %d", type);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif // _WIN32
|
|
|
|
|
|
|
|
|
|
extern std::string DataTypeToString(const proto::VarType::Type type);
|
|
|
|
|
extern size_t SizeOfType(std::type_index type);
|
|
|
|
|