|
|
|
@ -128,7 +128,9 @@ class DataType {
|
|
|
|
|
// @tparam T
|
|
|
|
|
// @return true or false
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool IsCompatible() const;
|
|
|
|
|
bool IsCompatible() const {
|
|
|
|
|
return type_ == FromCType<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// returns true if the template type is the same as the Tensor type_
|
|
|
|
|
// @tparam T
|
|
|
|
@ -146,6 +148,9 @@ class DataType {
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static DataType FromCType();
|
|
|
|
|
|
|
|
|
|
// Convert from DataType to Pybind type
|
|
|
|
|
// @return
|
|
|
|
|
py::dtype AsNumpyType() const;
|
|
|
|
@ -191,68 +196,68 @@ class DataType {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<bool>() const {
|
|
|
|
|
return type_ == DataType::DE_BOOL;
|
|
|
|
|
inline DataType DataType::FromCType<bool>() {
|
|
|
|
|
return DataType(DataType::DE_BOOL);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<double>() const {
|
|
|
|
|
return type_ == DataType::DE_FLOAT64;
|
|
|
|
|
inline DataType DataType::FromCType<double>() {
|
|
|
|
|
return DataType(DataType::DE_FLOAT64);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<float>() const {
|
|
|
|
|
return type_ == DataType::DE_FLOAT32;
|
|
|
|
|
inline DataType DataType::FromCType<float>() {
|
|
|
|
|
return DataType(DataType::DE_FLOAT32);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<float16>() const {
|
|
|
|
|
return type_ == DataType::DE_FLOAT16;
|
|
|
|
|
inline DataType DataType::FromCType<float16>() {
|
|
|
|
|
return DataType(DataType::DE_FLOAT16);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<int64_t>() const {
|
|
|
|
|
return type_ == DataType::DE_INT64;
|
|
|
|
|
inline DataType DataType::FromCType<int64_t>() {
|
|
|
|
|
return DataType(DataType::DE_INT64);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<uint64_t>() const {
|
|
|
|
|
return type_ == DataType::DE_UINT64;
|
|
|
|
|
inline DataType DataType::FromCType<uint64_t>() {
|
|
|
|
|
return DataType(DataType::DE_UINT64);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<int32_t>() const {
|
|
|
|
|
return type_ == DataType::DE_INT32;
|
|
|
|
|
inline DataType DataType::FromCType<int32_t>() {
|
|
|
|
|
return DataType(DataType::DE_INT32);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<uint32_t>() const {
|
|
|
|
|
return type_ == DataType::DE_UINT32;
|
|
|
|
|
inline DataType DataType::FromCType<uint32_t>() {
|
|
|
|
|
return DataType(DataType::DE_UINT32);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<int16_t>() const {
|
|
|
|
|
return type_ == DataType::DE_INT16;
|
|
|
|
|
inline DataType DataType::FromCType<int16_t>() {
|
|
|
|
|
return DataType(DataType::DE_INT16);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<uint16_t>() const {
|
|
|
|
|
return type_ == DataType::DE_UINT16;
|
|
|
|
|
inline DataType DataType::FromCType<uint16_t>() {
|
|
|
|
|
return DataType(DataType::DE_UINT16);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<int8_t>() const {
|
|
|
|
|
return type_ == DataType::DE_INT8;
|
|
|
|
|
inline DataType DataType::FromCType<int8_t>() {
|
|
|
|
|
return DataType(DataType::DE_INT8);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<uint8_t>() const {
|
|
|
|
|
return type_ == DataType::DE_UINT8;
|
|
|
|
|
inline DataType DataType::FromCType<uint8_t>() {
|
|
|
|
|
return DataType(DataType::DE_UINT8);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline bool DataType::IsCompatible<std::string_view>() const {
|
|
|
|
|
return type_ == DataType::DE_STRING;
|
|
|
|
|
inline DataType DataType::FromCType<std::string_view>() {
|
|
|
|
|
return DataType(DataType::DE_STRING);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|