@ -14,6 +14,7 @@ limitations under the License. */
# pragma once
# include <cstdint>
# include <memory>
# include <type_traits>
# include "paddle/framework/ddim.h"
@ -26,31 +27,65 @@ namespace framework {
class Tensor {
public :
Tensor ( ) : offset_ ( 0 ) { }
Tensor ( const DDim & dims ) : dims_ ( dims ) , offset_ ( 0 ) { }
template < typename T >
const T * data ( ) const {
PADDLE_ENFORCE ( holder_ ! = nullptr ,
" Tensor::data must be called after Tensor::mutable_data. " ) ;
return static_cast < const T * > ( holder_ - > Ptr ( ) ) ;
PADDLE_ENFORCE (
holder_ ! = nullptr ,
" Tenosr has not been initialized. Call Tensor::mutable_data first. " ) ;
return reinterpret_cast < const T * > (
reinterpret_cast < uintptr_t > ( holder_ - > Ptr ( ) ) + offset_ ) ;
}
template < typename T , // must be POD types
typename std : : enable_if < std : : is_pod < T > : : value > : : type * = nullptr >
T * mutable_data ( DDim dims , paddle : : platform : : Place place ) {
dims_ = dims ;
if ( holder_ = = nullptr | |
! ( holder_ - > Place ( ) = =
place ) /* some versions of boost::variant don't have operator!= */
| | holder_ - > Size ( ) < product ( dims ) * sizeof ( T ) ) {
| | holder_ - > Size ( ) < product ( dims ) * sizeof ( T ) + offset_ ) {
holder_ . reset ( new PlaceholderImpl < T > ( place , product ( dims ) * sizeof ( T ) ) ) ;
offset_ = 0 ;
}
return static_cast < T * > ( holder_ - > Ptr ( ) ) ;
return reinterpret_cast < T * > ( reinterpret_cast < uintptr_t > ( holder_ - > Ptr ( ) ) +
offset_ ) ;
}
template < typename T , // must be POD types
typename std : : enable_if < std : : is_pod < T > : : value > : : type * = nullptr >
T * mutable_data ( DDim dims ) {
return mutable_data < T > ( dims , paddle : : platform : : get_place ( ) ) ;
void ShareDataFrom ( const Tensor & src ) {
PADDLE_ENFORCE ( src . holder_ ! = nullptr ,
" Tenosr 'src' has not been initialized. " ) ;
holder_ = src . holder_ ;
dims_ = src . dims_ ;
offset_ = src . offset_ ;
}
Tensor Slice ( const int & begin_idx , const int & end_idx ) {
PADDLE_ENFORCE ( holder_ ! = nullptr ,
" The sliced tenosr has not been initialized. " ) ;
PADDLE_ENFORCE ( begin_idx > = 0 & & end_idx < = dims_ [ 0 ] ,
" Slice index is less than zero or out of bound. " ) ;
PADDLE_ENFORCE ( begin_idx < end_idx ,
" Begin index must be less than end index. " ) ;
PADDLE_ENFORCE ( dims_ [ 0 ] ! = 1 , " Can not slice a tensor with dims_[0] = 1. " ) ;
std : : vector < int > d = vectorize ( dims_ ) ;
int base = 1 ;
for ( size_t i = 1 ; i < d . size ( ) ; + + i ) {
base * = d [ i ] ;
}
Tensor dst ;
dst . holder_ = holder_ ;
dst . dims_ = dims_ ;
dst . dims_ [ 0 ] = end_idx - begin_idx ;
dst . offset_ = offset_ + begin_idx * base * holder_ - > TypeSize ( ) ;
return dst ;
}
DDim dims ( ) const { return dims_ ; }
private :
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
@ -59,6 +94,7 @@ class Tensor {
virtual void * Ptr ( ) const = 0 ;
virtual paddle : : platform : : Place Place ( ) const = 0 ;
virtual size_t Size ( ) const = 0 ;
virtual size_t TypeSize ( ) const = 0 ;
} ;
template < typename T >
@ -85,6 +121,7 @@ class Tensor {
virtual void * Ptr ( ) const { return static_cast < void * > ( ptr_ . get ( ) ) ; }
virtual size_t Size ( ) const { return size_ ; }
virtual paddle : : platform : : Place Place ( ) const { return place_ ; }
virtual size_t TypeSize ( ) const { return sizeof ( T ) ; }
std : : unique_ptr < T , Deleter > ptr_ ;
paddle : : platform : : Place place_ ; // record the place of ptr_.
@ -92,6 +129,8 @@ class Tensor {
} ;
std : : shared_ptr < Placeholder > holder_ ; // holds the memory block if allocated.
DDim dims_ ;
size_t offset_ ; // marks the begin of tensor data area.
} ;
} // namespace framework