@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License . */
limitations under the License . */
# include "paddle/fluid/extension/include/ext_tensor.h"
# include "paddle/fluid/extension/include/ext_tensor.h"
# include <utility>
# include <utility>
# include "paddle/fluid/framework/custom_tensor_utils.h"
# include "paddle/fluid/framework/custom_tensor_utils.h"
# include "paddle/fluid/framework/lod_tensor.h"
# include "paddle/fluid/framework/lod_tensor.h"
# include "paddle/fluid/memory/memcpy.h"
# include "paddle/fluid/memory/memcpy.h"
# include "paddle/fluid/platform/complex128.h"
# include "paddle/fluid/platform/complex64.h"
# include "paddle/fluid/platform/enforce.h"
# include "paddle/fluid/platform/enforce.h"
# include "paddle/fluid/platform/float16.h"
# include "paddle/fluid/platform/transform.h"
# include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace paddle {
@ -102,32 +97,13 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
void Tensor : : reshape ( const std : : vector < int64_t > & shape ) {
void Tensor : : reshape ( const std : : vector < int64_t > & shape ) {
GET_CASTED_TENSOR
GET_CASTED_TENSOR
auto new_dim = framework : : make_ddim ( shape ) ;
tensor - > Resize ( framework : : make_ddim ( shape ) ) ;
if ( tensor - > numel ( ) ! = framework : : product ( new_dim ) ) {
LOG ( WARNING ) < < " Custom Op: Calling reshape to a new shape which is bigger "
" or smaller "
< < " than original shape will not change your tensor's memory "
" Please call "
< < " paddle::Tensor::mutable_data<T>() after to reallocate "
" your tensor's size. "
< < std : : endl ;
}
tensor - > Resize ( new_dim ) ;
}
}
Tensor : : Tensor ( const PlaceType & place )
Tensor : : Tensor ( const PlaceType & place )
: tensor_ ( std : : make_shared < framework : : LoDTensor > ( ) ) ,
: tensor_ ( std : : make_shared < framework : : LoDTensor > ( ) ) ,
place_ ( place ) ,
place_ ( place ) ,
stream_ ( StreamWrapper ( ) ) { }
stream_ ( StreamWrapper ( ) ) { }
Tensor : : Tensor ( const PlaceType & place , const std : : vector < int64_t > & shape )
: tensor_ ( std : : make_shared < framework : : LoDTensor > ( ) ) ,
place_ ( place ) ,
stream_ ( StreamWrapper ( ) ) {
GET_CASTED_TENSOR
tensor - > Resize ( framework : : make_ddim ( shape ) ) ;
}
template < typename T >
template < typename T >
T * Tensor : : mutable_data ( const PlaceType & place ) {
T * Tensor : : mutable_data ( const PlaceType & place ) {
place_ = place ;
place_ = place ;
@ -186,12 +162,6 @@ DataType Tensor::type() const {
return DataType : : FLOAT64 ;
return DataType : : FLOAT64 ;
} else if ( type = = framework : : proto : : VarType : : BOOL ) {
} else if ( type = = framework : : proto : : VarType : : BOOL ) {
return DataType : : BOOL ;
return DataType : : BOOL ;
} else if ( type = = framework : : proto : : VarType : : COMPLEX64 ) {
return DataType : : COMPLEX64 ;
} else if ( type = = framework : : proto : : VarType : : COMPLEX128 ) {
return DataType : : COMPLEX128 ;
} else if ( type = = framework : : proto : : VarType : : FP16 ) {
return DataType : : FLOAT16 ;
}
}
// TODO(JiabinYang) Support more dtype here
// TODO(JiabinYang) Support more dtype here
return DataType : : FLOAT32 ;
return DataType : : FLOAT32 ;
@ -247,12 +217,6 @@ template PD_DLL_DECL Tensor
Tensor : : copy_to < int16_t > ( const PlaceType & target_place ) const ;
Tensor : : copy_to < int16_t > ( const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor
template PD_DLL_DECL Tensor
Tensor : : copy_to < bool > ( const PlaceType & target_place ) const ;
Tensor : : copy_to < bool > ( const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor Tensor : : copy_to < paddle : : platform : : complex64 > (
const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor Tensor : : copy_to < paddle : : platform : : complex128 > (
const PlaceType & target_place ) const ;
template PD_DLL_DECL Tensor
Tensor : : copy_to < paddle : : platform : : float16 > ( const PlaceType & target_place ) const ;
template PD_DLL_DECL float * Tensor : : data < float > ( ) const ;
template PD_DLL_DECL float * Tensor : : data < float > ( ) const ;
template PD_DLL_DECL double * Tensor : : data < double > ( ) const ;
template PD_DLL_DECL double * Tensor : : data < double > ( ) const ;
@ -262,12 +226,6 @@ template PD_DLL_DECL uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL int8_t * Tensor : : data < int8_t > ( ) const ;
template PD_DLL_DECL int8_t * Tensor : : data < int8_t > ( ) const ;
template PD_DLL_DECL int16_t * Tensor : : data < int16_t > ( ) const ;
template PD_DLL_DECL int16_t * Tensor : : data < int16_t > ( ) const ;
template PD_DLL_DECL bool * Tensor : : data < bool > ( ) const ;
template PD_DLL_DECL bool * Tensor : : data < bool > ( ) const ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : data < paddle : : platform : : complex64 > ( ) const ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : data < paddle : : platform : : complex128 > ( ) const ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : data < paddle : : platform : : float16 > ( ) const ;
template PD_DLL_DECL float * Tensor : : mutable_data < float > ( ) ;
template PD_DLL_DECL float * Tensor : : mutable_data < float > ( ) ;
template PD_DLL_DECL double * Tensor : : mutable_data < double > ( ) ;
template PD_DLL_DECL double * Tensor : : mutable_data < double > ( ) ;
@ -277,12 +235,6 @@ template PD_DLL_DECL uint8_t *Tensor::mutable_data<uint8_t>();
template PD_DLL_DECL int8_t * Tensor : : mutable_data < int8_t > ( ) ;
template PD_DLL_DECL int8_t * Tensor : : mutable_data < int8_t > ( ) ;
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > ( ) ;
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > ( ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( ) ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : mutable_data < paddle : : platform : : complex64 > ( ) ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : mutable_data < paddle : : platform : : complex128 > ( ) ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : mutable_data < paddle : : platform : : float16 > ( ) ;
template PD_DLL_DECL float * Tensor : : mutable_data < float > ( const PlaceType & place ) ;
template PD_DLL_DECL float * Tensor : : mutable_data < float > ( const PlaceType & place ) ;
template PD_DLL_DECL double * Tensor : : mutable_data < double > (
template PD_DLL_DECL double * Tensor : : mutable_data < double > (
@ -298,12 +250,6 @@ template PD_DLL_DECL int8_t *Tensor::mutable_data<int8_t>(
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > (
template PD_DLL_DECL int16_t * Tensor : : mutable_data < int16_t > (
const PlaceType & place ) ;
const PlaceType & place ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( const PlaceType & place ) ;
template PD_DLL_DECL bool * Tensor : : mutable_data < bool > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : complex64 *
Tensor : : mutable_data < paddle : : platform : : complex64 > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : complex128 *
Tensor : : mutable_data < paddle : : platform : : complex128 > ( const PlaceType & place ) ;
template PD_DLL_DECL paddle : : platform : : float16 *
Tensor : : mutable_data < paddle : : platform : : float16 > ( const PlaceType & place ) ;
std : : vector < int64_t > Tensor : : shape ( ) const {
std : : vector < int64_t > Tensor : : shape ( ) const {
GET_CASTED_TENSOR
GET_CASTED_TENSOR
@ -364,21 +310,6 @@ Tensor Tensor::cast(const DataType &target_type) const {
framework : : VisitDataType (
framework : : VisitDataType (
dst_type , CastDataType < uint8_t > ( * tensor , rlt_tensor_ , ctx ) ) ;
dst_type , CastDataType < uint8_t > ( * tensor , rlt_tensor_ , ctx ) ) ;
break ;
break ;
case framework : : proto : : VarType : : COMPLEX64 :
framework : : VisitDataType (
dst_type ,
CastDataType < paddle : : platform : : complex64 > ( * tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : COMPLEX128 :
framework : : VisitDataType ( dst_type ,
CastDataType < paddle : : platform : : complex128 > (
* tensor , rlt_tensor_ , ctx ) ) ;
break ;
case framework : : proto : : VarType : : FP16 :
framework : : VisitDataType (
dst_type ,
CastDataType < paddle : : platform : : float16 > ( * tensor , rlt_tensor_ , ctx ) ) ;
break ;
// TODO(JiabinYang) Support more dtype here
// TODO(JiabinYang) Support more dtype here
default :
default :
PADDLE_THROW ( platform : : errors : : Unimplemented (
PADDLE_THROW ( platform : : errors : : Unimplemented (