@ -18,6 +18,8 @@ limitations under the License. */
# include "paddle/fluid/framework/framework.pb.h"
# include "paddle/fluid/platform/bfloat16.h"
# include "paddle/fluid/platform/complex128.h"
# include "paddle/fluid/platform/complex64.h"
# include "paddle/fluid/platform/enforce.h"
# include "paddle/fluid/platform/float16.h"
@ -25,6 +27,8 @@ namespace paddle {
namespace platform {
struct bfloat16 ;
struct float16 ;
struct complex64 ;
struct complex128 ;
} // namespace platform
} // namespace paddle
@ -45,23 +49,27 @@ struct DataTypeTrait<void> {
# define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback ( cpp_type , : : paddle : : framework : : proto : : VarType : : proto_type ) ;
# define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_ ( callback , float , FP32 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : float16 , FP16 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : bfloat16 , BF16 ) ; \
_ForEachDataTypeHelper_ ( callback , double , FP64 ) ; \
_ForEachDataTypeHelper_ ( callback , int , INT32 ) ; \
_ForEachDataTypeHelper_ ( callback , int64_t , INT64 ) ; \
_ForEachDataTypeHelper_ ( callback , bool , BOOL ) ; \
_ForEachDataTypeHelper_ ( callback , uint8_t , UINT8 ) ; \
_ForEachDataTypeHelper_ ( callback , int16_t , INT16 ) ; \
_ForEachDataTypeHelper_ ( callback , int8_t , INT8 )
# define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_ ( callback , float , FP32 ) ; \
_ForEachDataTypeHelper_ ( callback , double , FP64 ) ; \
_ForEachDataTypeHelper_ ( callback , int , INT32 ) ; \
_ForEachDataTypeHelper_ ( callback , int64_t , INT64 ) ;
# define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_ ( callback , float , FP32 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : float16 , FP16 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : bfloat16 , BF16 ) ; \
_ForEachDataTypeHelper_ ( callback , double , FP64 ) ; \
_ForEachDataTypeHelper_ ( callback , int , INT32 ) ; \
_ForEachDataTypeHelper_ ( callback , int64_t , INT64 ) ; \
_ForEachDataTypeHelper_ ( callback , bool , BOOL ) ; \
_ForEachDataTypeHelper_ ( callback , uint8_t , UINT8 ) ; \
_ForEachDataTypeHelper_ ( callback , int16_t , INT16 ) ; \
_ForEachDataTypeHelper_ ( callback , int8_t , INT8 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : complex64 , COMPLEX64 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : complex128 , COMPLEX128 ) ;
# define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_ ( callback , float , FP32 ) ; \
_ForEachDataTypeHelper_ ( callback , double , FP64 ) ; \
_ForEachDataTypeHelper_ ( callback , int , INT32 ) ; \
_ForEachDataTypeHelper_ ( callback , int64_t , INT64 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : complex64 , COMPLEX64 ) ; \
_ForEachDataTypeHelper_ ( callback , : : paddle : : platform : : complex128 , COMPLEX128 ) ;
// For the use of thrust, as index-type elements can be only integers.
# define _ForEachDataTypeTiny_(callback) \