|
|
|
@ -22,46 +22,59 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct DataTypeTrait {};
|
|
|
|
|
|
|
|
|
|
// Stub handle for void
|
|
|
|
|
template <>
|
|
|
|
|
struct DataTypeTrait<void> {
|
|
|
|
|
constexpr static auto DataType = proto::VarType::RAW;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
|
|
|
|
|
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
|
|
|
|
|
|
|
|
|
|
#define _ForEachDataType_(callback) \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, float, FP32); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, double, FP64); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, int, INT32); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, bool, BOOL); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
|
|
|
|
|
_ForEachDataTypeHelper_(callback, int8_t, INT8)
|
|
|
|
|
|
|
|
|
|
#define DefineDataTypeTrait(cpp_type, proto_type) \
|
|
|
|
|
template <> \
|
|
|
|
|
struct DataTypeTrait<cpp_type> { \
|
|
|
|
|
constexpr static auto DataType = proto_type; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_ForEachDataType_(DefineDataTypeTrait);
|
|
|
|
|
|
|
|
|
|
#undef DefineDataTypeTrait
|
|
|
|
|
|
|
|
|
|
extern proto::VarType::Type ToDataType(std::type_index type);
|
|
|
|
|
extern std::type_index ToTypeIndex(proto::VarType::Type type);
|
|
|
|
|
|
|
|
|
|
template <typename Visitor>
|
|
|
|
|
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
visitor.template apply<platform::float16>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
visitor.template apply<float>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
visitor.template apply<double>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
visitor.template apply<int>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
visitor.template apply<int64_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
visitor.template apply<bool>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::UINT8:
|
|
|
|
|
visitor.template apply<uint8_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT16:
|
|
|
|
|
visitor.template apply<int16_t>();
|
|
|
|
|
break;
|
|
|
|
|
case proto::VarType::INT8:
|
|
|
|
|
visitor.template apply<int8_t>();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not supported %d", type);
|
|
|
|
|
}
|
|
|
|
|
#define VisitDataTypeCallback(cpp_type, proto_type) \
|
|
|
|
|
do { \
|
|
|
|
|
if (type == proto_type) { \
|
|
|
|
|
visitor.template apply<cpp_type>(); \
|
|
|
|
|
return; \
|
|
|
|
|
} \
|
|
|
|
|
} while (0)
|
|
|
|
|
|
|
|
|
|
_ForEachDataType_(VisitDataTypeCallback);
|
|
|
|
|
#undef VisitDataTypeCallback
|
|
|
|
|
PADDLE_THROW("Not supported %d", type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
extern std::string DataTypeToString(const proto::VarType::Type type);
|
|
|
|
|
extern size_t SizeOfType(std::type_index type);
|
|
|
|
|
extern size_t SizeOfType(proto::VarType::Type type);
|
|
|
|
|
inline std::ostream& operator<<(std::ostream& out,
|
|
|
|
|
const proto::VarType::Type& type) {
|
|
|
|
|
out << DataTypeToString(type);
|
|
|
|
|