|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <typeindex>
|
|
|
|
|
#include "paddle/fluid/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/fluid/platform/enforce.h"
|
|
|
|
@ -22,18 +23,21 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
inline proto::VarType::Type ToDataType(std::type_index type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
if (typeid(platform::float16).hash_code() == type.hash_code()) {
|
|
|
|
|
return proto::VarType::FP16;
|
|
|
|
|
} else if (typeid(float).hash_code() == type.hash_code()) {
|
|
|
|
|
} else if (typeid(const float).hash_code() == type.hash_code()) {
|
|
|
|
|
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
|
|
|
|
|
// One fix to this is to replace float with const float because
|
|
|
|
|
// typeid(T) == typeid(const T)
|
|
|
|
|
// http://en.cppreference.com/w/cpp/language/typeid
|
|
|
|
|
return proto::VarType::FP32;
|
|
|
|
|
} else if (typeid(double).hash_code() == type.hash_code()) {
|
|
|
|
|
} else if (typeid(const double).hash_code() == type.hash_code()) {
|
|
|
|
|
return proto::VarType::FP64;
|
|
|
|
|
} else if (typeid(int).hash_code() == type.hash_code()) {
|
|
|
|
|
} else if (typeid(const int).hash_code() == type.hash_code()) {
|
|
|
|
|
return proto::VarType::INT32;
|
|
|
|
|
} else if (typeid(int64_t).hash_code() == type.hash_code()) {
|
|
|
|
|
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
|
|
|
|
|
return proto::VarType::INT64;
|
|
|
|
|
} else if (typeid(bool).hash_code() == type.hash_code()) {
|
|
|
|
|
} else if (typeid(const bool).hash_code() == type.hash_code()) {
|
|
|
|
|
return proto::VarType::BOOL;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Not supported");
|
|
|
|
@ -41,7 +45,6 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::type_index ToTypeIndex(proto::VarType::Type type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
return typeid(platform::float16);
|
|
|
|
@ -62,7 +65,6 @@ inline std::type_index ToTypeIndex(proto::VarType::Type type) {
|
|
|
|
|
|
|
|
|
|
template <typename Visitor>
|
|
|
|
|
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
visitor.template operator()<platform::float16>();
|
|
|
|
@ -88,7 +90,6 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline std::string DataTypeToString(const proto::VarType::Type type) {
|
|
|
|
|
using namespace paddle::framework::proto;
|
|
|
|
|
switch (type) {
|
|
|
|
|
case proto::VarType::FP16:
|
|
|
|
|
return "float16";
|
|
|
|
|