@ -159,17 +159,10 @@ DataType Tensor::type() const {
return DataType : : UINT8 ;
} else if ( type = = framework : : proto : : VarType : : FP64 ) {
return DataType : : FLOAT64 ;
} else if ( type = = framework : : proto : : VarType : : BF16 ) {
return DataType : : BFLOAT16 ;
} else if ( type = = framework : : proto : : VarType : : FP16 ) {
return DataType : : FLOAT16 ;
} else if ( type = = framework : : proto : : VarType : : COMPLEX64 ) {
return DataType : : COMPLEX64 ;
} else if ( type = = framework : : proto : : VarType : : COMPLEX128 ) {
return DataType : : COMPLEX128 ;
} else if ( type = = framework : : proto : : VarType : : BOOL ) {
return DataType : : BOOL ;
}
// TODO(JiabinYang) Support more dtype here
return DataType : : FLOAT32 ;
}
@ -207,14 +200,6 @@ Tensor Tensor::copy_to(const PlaceType &target_place) const {
return target ;
}
template PD_DLL_DECL Tensor
Tensor : : copy_to < paddle : : platform : : float16 > ( const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor Tensor : : copy_to < paddle : : platform : : bfloat16 > (
const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor Tensor : : copy_to < paddle : : platform : : complex64 > (
const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor Tensor : : copy_to < paddle : : platform : : complex128 > (
const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor
Tensor : : copy_to < float > ( const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor
@ -238,14 +223,6 @@ template PD_DLL_DECL int64_t *Tensor::data<int64_t>() const;
template PD_DLL_DECL int32_t * Tensor : : data < int32_t > ( ) const ;
template PD_DLL_DECL uint8_t * Tensor : : data < uint8_t > ( ) const ;
template PD_DLL_DECL int8_t * Tensor : : data < int8_t > ( ) const ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : data < paddle : : platform : : float16 > ( ) const ;
template PD_DLL_DECL paddle : : platform : : bfloat16 *
Tensor : : data < paddle : : platform : : bfloat16 > ( ) const ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : data < paddle : : platform : : complex128 > ( ) const ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : data < paddle : : platform : : complex64 > ( ) const ;
template PD_DLL_DECL int16_t * Tensor : : data < int16_t > ( ) const ;
template PD_DLL_DECL bool * Tensor : : data < bool > ( ) const ;
@ -255,14 +232,6 @@ template PD_DLL_DECL int64_t *Tensor::mutable_data<int64_t>();
template PD_DLL_DECL int32_t * Tensor : : mutable_data < int32_t > ( ) ;
template PD_DLL_DECL uint8_t * Tensor : : mutable_data < uint8_t > ( ) ;
template PD_DLL_DECL int8_t * Tensor : : mutable_data < int8_t > ( ) ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : mutable_data < paddle : : platform : : float16 > ( ) ;
template PD_DLL_DECL paddle : : platform : : bfloat16 *
Tensor : : mutable_data < paddle : : platform : : bfloat16 > ( ) ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : mutable_data < paddle : : platform : : complex128 > ( ) ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : mutable_data < paddle : : platform : : complex64 > ( ) ;
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > ( ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( ) ;
@ -277,14 +246,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>(
const PlaceType & place ) ;
template PD_DLL_DECL int8_t * Tensor : : mutable_data < int8_t > (
const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : mutable_data < paddle : : platform : : float16 > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : bfloat16 *
Tensor : : mutable_data < paddle : : platform : : bfloat16 > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : mutable_data < paddle : : platform : : complex128 > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : mutable_data < paddle : : platform : : complex64 > ( const PlaceType & place ) ;
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > (
const PlaceType & place ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( const PlaceType & place ) ;
@ -320,14 +281,6 @@ Tensor Tensor::cast(const DataType &target_type) const {
auto dst_type =
framework : : CustomTensorUtils : : ConvertEnumDTypeToInnerDType ( target_type ) ;
switch ( src_type ) {
case framework : : proto : : VarType : : FP16 :
framework : : VisitDataType (
dst_type , CastDataType < platform : : float16 > ( * tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : BF16 :
framework : : VisitDataType ( dst_type , CastDataType < platform : : bfloat16 > (
* tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : FP32 :
framework : : VisitDataType ( dst_type ,
CastDataType < float > ( * tensor , rlt_tensor_ , ctx ) ) ;
@ -356,14 +309,7 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework : : VisitDataType (
dst_type , CastDataType < uint8_t > ( * tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : COMPLEX64 :
framework : : VisitDataType ( dst_type , CastDataType < platform : : complex64 > (
* tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : COMPLEX128 :
framework : : VisitDataType ( dst_type , CastDataType < platform : : complex128 > (
* tensor , rlt_tensor_ , ctx ) ) ;
break ;
// TODO(JiabinYang) Support more dtype here
default :
PADDLE_THROW ( platform : : errors : : Unimplemented (
" Data type (%s) is not supported when casting data type. " ,