@ -16,18 +16,25 @@
# ifndef DATASET_CORE_DATA_TYPE_H_
# define DATASET_CORE_DATA_TYPE_H_
# include <opencv2/core/hal/interface.h>
# include <string>
# include "pybind11/numpy.h"
# include "pybind11/pybind11.h"
# include "dataset/core/constants.h"
# include "dataset/core/pybind_support.h"
namespace py = pybind11 ;
namespace mindspore {
namespace dataset {
// Class that represents basic data types in DataEngine.
class DataType {
public :
enum Type : uint8_t {
DE_UNKNOWN = 0 ,
DE_BOOL ,
DE_INT8 ,
DE_UINT8 ,
@ -40,20 +47,60 @@ class DataType {
DE_FLOAT16 ,
DE_FLOAT32 ,
DE_FLOAT64 ,
DE_UNKNOWN
DE_STRING ,
NUM_OF_TYPES
} ;
static constexpr uint8_t DE_BOOL_SIZE = 1 ;
static constexpr uint8_t DE_UINT8_SIZE = 1 ;
static constexpr uint8_t DE_INT8_SIZE = 1 ;
static constexpr uint8_t DE_UINT16_SIZE = 2 ;
static constexpr uint8_t DE_INT16_SIZE = 2 ;
static constexpr uint8_t DE_UINT32_SIZE = 4 ;
static constexpr uint8_t DE_INT32_SIZE = 4 ;
static constexpr uint8_t DE_INT64_SIZE = 8 ;
static constexpr uint8_t DE_UINT64_SIZE = 8 ;
static constexpr uint8_t DE_FLOAT32_SIZE = 4 ;
static constexpr uint8_t DE_FLOAT64_SIZE = 8 ;
inline static constexpr uint8_t SIZE_IN_BYTES [ ] = { 0 , // DE_UNKNOWN
1 , // DE_BOOL
1 , // DE_INT8
1 , // DE_UINT8
2 , // DE_INT16
2 , // DE_UINT16
4 , // DE_INT32
4 , // DE_UINT32
8 , // DE_INT64
8 , // DE_UINT64
2 , // DE_FLOAT16
4 , // DE_FLOAT32
8 , // DE_FLOAT64
0 } ; // DE_STRING
inline static const char * TO_STRINGS [ ] = { " unknown " , " bool " , " int8 " , " uint8 " , " int16 " , " uint16 " , " int32 " ,
" uint32 " , " int64 " , " uint64 " , " float16 " , " float32 " , " float64 " , " string " } ;
inline static const char * PYBIND_TYPES [ ] = { " object " , " bool " , " int8 " , " uint8 " , " int16 " , " uint16 " , " int32 " ,
" uint32 " , " int64 " , " uint64 " , " float16 " , " float32 " , " double " , " bytes " } ;
inline static const std : : string PYBIND_FORMAT_DESCRIPTOR [ ] = { " " , // DE_UNKNOWN
py : : format_descriptor < bool > : : format ( ) , // DE_BOOL
py : : format_descriptor < int8_t > : : format ( ) , // DE_INT8
py : : format_descriptor < uint8_t > : : format ( ) , // DE_UINT8
py : : format_descriptor < int16_t > : : format ( ) , // DE_INT16
py : : format_descriptor < uint16_t > : : format ( ) , // DE_UINT16
py : : format_descriptor < int32_t > : : format ( ) , // DE_INT32
py : : format_descriptor < uint32_t > : : format ( ) , // DE_UINT32
py : : format_descriptor < int64_t > : : format ( ) , // DE_INT64
py : : format_descriptor < uint64_t > : : format ( ) , // DE_UINT64
" e " , // DE_FLOAT16
py : : format_descriptor < float > : : format ( ) , // DE_FLOAT32
py : : format_descriptor < double > : : format ( ) , // DE_FLOAT64
" S " } ; // DE_STRING
inline static constexpr uint8_t CV_TYPES [ ] = { kCVInvalidType , // DE_UNKNOWN
CV_8U , // DE_BOOL
CV_8S , // DE_INT8
CV_8U , // DE_UINT8
CV_16S , // DE_INT16
CV_16U , // DE_UINT16
CV_32S , // DE_INT32
kCVInvalidType , // DE_UINT32
kCVInvalidType , // DE_INT64
kCVInvalidType , // DE_UINT64
CV_16F , // DE_FLOAT16
CV_32F , // DE_FLOAT32
CV_64F , // DE_FLOAT64
kCVInvalidType } ; // DE_STRING
// No arg constructor to create an unknown shape
DataType ( ) : type_ ( DE_UNKNOWN ) { }
@ -160,6 +207,8 @@ class DataType {
bool IsBool ( ) const { return type_ = = DataType : : DE_BOOL ; }
bool IsNumeric ( ) const { return type_ ! = DataType : : DE_STRING ; }
Type value ( ) const { return type_ ; }
private :
@ -226,6 +275,11 @@ inline bool DataType::IsCompatible<uint8_t>() const {
return type_ = = DataType : : DE_UINT8 ;
}
template < >
inline bool DataType : : IsCompatible < std : : string_view > ( ) const {
return type_ = = DataType : : DE_STRING ;
}
template < >
inline bool DataType : : IsLooselyCompatible < bool > ( ) const {
return type_ = = DataType : : DE_BOOL ;