@ -15,6 +15,7 @@ limitations under the License. */
# pragma once
# pragma once
# include <cstdint>
# include <cstdint>
# include <cstring>
# include <memory>
# include <memory>
# include <type_traits>
# include <type_traits>
# include "paddle/framework/ddim.h"
# include "paddle/framework/ddim.h"
@ -44,11 +45,17 @@ class Tensor {
typename std : : enable_if < std : : is_pod < T > : : value > : : type * = nullptr >
typename std : : enable_if < std : : is_pod < T > : : value > : : type * = nullptr >
T * mutable_data ( DDim dims , paddle : : platform : : Place place ) {
T * mutable_data ( DDim dims , paddle : : platform : : Place place ) {
dims_ = dims ;
dims_ = dims ;
return mutable_data < T > ( place ) ;
}
template < typename T , // must be POD types
typename std : : enable_if < std : : is_pod < T > : : value > : : type * = nullptr >
T * mutable_data ( paddle : : platform : : Place place ) {
if ( holder_ = = nullptr | |
if ( holder_ = = nullptr | |
! ( holder_ - > Place ( ) = =
! ( holder_ - > Place ( ) = =
place ) /* some versions of boost::variant don't have operator!= */
place ) /* some versions of boost::variant don't have operator!= */
| | holder_ - > Size ( ) < product ( dims ) * sizeof ( T ) + offset_ ) {
| | holder_ - > Size ( ) < product ( dims _ ) * sizeof ( T ) + offset_ ) {
holder_ . reset ( new PlaceholderImpl < T > ( place , product ( dims ) * sizeof ( T ) ) ) ;
holder_ . reset ( new PlaceholderImpl < T > ( place , product ( dims _ ) * sizeof ( T ) ) ) ;
offset_ = 0 ;
offset_ = 0 ;
}
}
return reinterpret_cast < T * > ( reinterpret_cast < uintptr_t > ( holder_ - > Ptr ( ) ) +
return reinterpret_cast < T * > ( reinterpret_cast < uintptr_t > ( holder_ - > Ptr ( ) ) +
@ -63,6 +70,15 @@ class Tensor {
offset_ = src . offset_ ;
offset_ = src . offset_ ;
}
}
void CopyFrom ( const Tensor & src , paddle : : platform : : Place dst_place ) {
PADDLE_ENFORCE ( src . holder_ ! = nullptr ,
" Can not copy from an uninitialized tensor. " ) ;
size_t size = product ( src . dims ( ) ) * src . holder_ - > TypeSize ( ) ;
holder_ . reset ( src . holder_ - > Clone ( src . offset_ , size , dst_place ) ) ;
dims_ = src . dims ( ) ;
offset_ = 0 ;
}
Tensor Slice ( const int & begin_idx , const int & end_idx ) const {
Tensor Slice ( const int & begin_idx , const int & end_idx ) const {
PADDLE_ENFORCE ( holder_ ! = nullptr ,
PADDLE_ENFORCE ( holder_ ! = nullptr ,
" The sliced tenosr has not been initialized. " ) ;
" The sliced tenosr has not been initialized. " ) ;
@ -95,6 +111,8 @@ class Tensor {
virtual paddle : : platform : : Place Place ( ) const = 0 ;
virtual paddle : : platform : : Place Place ( ) const = 0 ;
virtual size_t Size ( ) const = 0 ;
virtual size_t Size ( ) const = 0 ;
virtual size_t TypeSize ( ) const = 0 ;
virtual size_t TypeSize ( ) const = 0 ;
virtual Placeholder * Clone ( size_t begin , size_t size ,
paddle : : platform : : Place place ) const = 0 ;
} ;
} ;
template < typename T >
template < typename T >
@ -122,6 +140,18 @@ class Tensor {
virtual size_t Size ( ) const { return size_ ; }
virtual size_t Size ( ) const { return size_ ; }
virtual paddle : : platform : : Place Place ( ) const { return place_ ; }
virtual paddle : : platform : : Place Place ( ) const { return place_ ; }
virtual size_t TypeSize ( ) const { return sizeof ( T ) ; }
virtual size_t TypeSize ( ) const { return sizeof ( T ) ; }
// TODO: Clone only support CPU now. GPU support is needed.
virtual Placeholder * Clone ( size_t begin , size_t size ,
paddle : : platform : : Place place ) const {
PADDLE_ENFORCE ( paddle : : platform : : is_cpu_place ( place_ ) & &
paddle : : platform : : is_cpu_place ( place ) ,
" PlaceholderImpl::Clone only support CPU now. " ) ;
PlaceholderImpl < T > * dst = new PlaceholderImpl < T > ( place , size ) ;
void * begin_ptr =
reinterpret_cast < void * > ( reinterpret_cast < uintptr_t > ( Ptr ( ) ) + begin ) ;
memcpy ( dst - > Ptr ( ) , begin_ptr , size ) ;
return dst ;
}
std : : unique_ptr < T , Deleter > ptr_ ;
std : : unique_ptr < T , Deleter > ptr_ ;
paddle : : platform : : Place place_ ; // record the place of ptr_.
paddle : : platform : : Place place_ ; // record the place of ptr_.