|
|
|
@ -20,35 +20,35 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
inline proto::DataType ToDataType(std::type_index type) {
|
|
|
|
|
inline proto::VarType::Type ToDataType(std::type_index type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
if (typeid(float).hash_code() == type.hash_code()) {
|
|
|
|
|
return DataType::FP32;
|
|
|
|
|
return proto::VarType::FP32;
|
|
|
|
|
} else if (typeid(double).hash_code() == type.hash_code()) {
|
|
|
|
|
return DataType::FP64;
|
|
|
|
|
return proto::VarType::FP64;
|
|
|
|
|
} else if (typeid(int).hash_code() == type.hash_code()) {
|
|
|
|
|
return DataType::INT32;
|
|
|
|
|
return proto::VarType::INT32;
|
|
|
|
|
} else if (typeid(int64_t).hash_code() == type.hash_code()) {
|
|
|
|
|
return DataType::INT64;
|
|
|
|
|
return proto::VarType::INT64;
|
|
|
|
|
} else if (typeid(bool).hash_code() == type.hash_code()) {
|
|
|
|
|
return DataType::BOOL;
|
|
|
|
|
return proto::VarType::BOOL;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not supported");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::type_index ToTypeIndex(proto::DataType type) {
|
|
|
|
|
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case DataType::FP32:
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
return typeid(float);
|
|
|
|
|
case DataType::FP64:
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
return typeid(double);
|
|
|
|
|
case DataType::INT32:
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
return typeid(int);
|
|
|
|
|
case DataType::INT64:
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
return typeid(int64_t);
|
|
|
|
|
case DataType::BOOL:
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
return typeid(bool);
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not support type %d", type);
|
|
|
|
@ -56,22 +56,22 @@ inline std::type_index ToTypeIndex(proto::DataType type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Visitor>
|
|
|
|
|
inline void VisitDataType(proto::DataType type, Visitor visitor) {
|
|
|
|
|
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case DataType::FP32:
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
visitor.template operator()<float>();
|
|
|
|
|
break;
|
|
|
|
|
case DataType::FP64:
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
visitor.template operator()<double>();
|
|
|
|
|
break;
|
|
|
|
|
case DataType::INT32:
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
visitor.template operator()<int>();
|
|
|
|
|
break;
|
|
|
|
|
case DataType::INT64:
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
visitor.template operator()<int64_t>();
|
|
|
|
|
break;
|
|
|
|
|
case DataType::BOOL:
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
visitor.template operator()<bool>();
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
@ -79,22 +79,22 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::string DataTypeToString(const proto::DataType type) {
|
|
|
|
|
inline std::string DataTypeToString(const proto::VarType::Type type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case DataType::FP16:
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
return "float16";
|
|
|
|
|
case DataType::FP32:
|
|
|
|
|
case proto::VarType::FP32:
|
|
|
|
|
return "float32";
|
|
|
|
|
case DataType::FP64:
|
|
|
|
|
case proto::VarType::FP64:
|
|
|
|
|
return "float64";
|
|
|
|
|
case DataType::INT16:
|
|
|
|
|
case proto::VarType::INT16:
|
|
|
|
|
return "int16";
|
|
|
|
|
case DataType::INT32:
|
|
|
|
|
case proto::VarType::INT32:
|
|
|
|
|
return "int32";
|
|
|
|
|
case DataType::INT64:
|
|
|
|
|
case proto::VarType::INT64:
|
|
|
|
|
return "int64";
|
|
|
|
|
case DataType::BOOL:
|
|
|
|
|
case proto::VarType::BOOL:
|
|
|
|
|
return "bool";
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW("Not support type %d", type);
|
|
|
|
@ -102,7 +102,7 @@ inline std::string DataTypeToString(const proto::DataType type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::ostream& operator<<(std::ostream& out,
|
|
|
|
|
const proto::DataType& type) {
|
|
|
|
|
const proto::VarType::Type& type) {
|
|
|
|
|
out << DataTypeToString(type);
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|