|
|
|
@ -431,14 +431,14 @@ inline void PyCUDAPinnedTensorSetFromArray(
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
constexpr bool IsValidDTypeToPyArray() {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
|
|
|
|
|
template <> \
|
|
|
|
|
constexpr bool IsValidDTypeToPyArray<type>() { \
|
|
|
|
|
return true; \
|
|
|
|
|
struct ValidDTypeToPyArrayChecker {
|
|
|
|
|
static constexpr bool kValue = false;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
|
|
|
|
|
template <> \
|
|
|
|
|
struct ValidDTypeToPyArrayChecker<type> { \
|
|
|
|
|
static constexpr bool kValue = true; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
|
|
|
|
@ -452,15 +452,16 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
|
|
|
|
|
|
|
|
|
|
inline std::string TensorDTypeToPyDTypeStr(
|
|
|
|
|
framework::proto::VarType::Type type) {
|
|
|
|
|
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \
|
|
|
|
|
if (type == proto_type) { \
|
|
|
|
|
if (std::is_same<T, platform::float16>::value) { \
|
|
|
|
|
return "e"; \
|
|
|
|
|
} else { \
|
|
|
|
|
PADDLE_ENFORCE(IsValidDTypeToPyArray<T>, \
|
|
|
|
|
"This type of tensor cannot be expose to Python"); \
|
|
|
|
|
return py::format_descriptor<T>::format(); \
|
|
|
|
|
} \
|
|
|
|
|
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \
|
|
|
|
|
if (type == proto_type) { \
|
|
|
|
|
if (std::is_same<T, platform::float16>::value) { \
|
|
|
|
|
return "e"; \
|
|
|
|
|
} else { \
|
|
|
|
|
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
|
|
|
|
|
PADDLE_ENFORCE(kIsValidDType, \
|
|
|
|
|
"This type of tensor cannot be expose to Python"); \
|
|
|
|
|
return py::format_descriptor<T>::format(); \
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE);
|
|
|
|
|