@ -23,6 +23,7 @@
# include "paddle/fluid/framework/tensor_util.h"
# include "paddle/fluid/operators/assign_value_op.h"
# include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
# include "paddle/fluid/operators/utils.h"
# include "paddle/fluid/platform/enforce.h"
namespace paddle {
@ -58,26 +59,70 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
return value_name ;
}
inline void CheckAndUpdateSlice ( const framework : : DDim in_dims ,
const std : : vector < int64_t > axes ,
std : : vector < int64_t > * starts ,
std : : vector < int64_t > * ends ,
std : : vector < int64_t > * steps ) {
for ( size_t i = 0 ; i < axes . size ( ) ; + + i ) {
int64_t axis = axes [ i ] ;
int64_t dim_value = in_dims [ axis ] ;
int64_t start =
( * starts ) [ i ] < 0 ? ( ( * starts ) [ i ] + dim_value ) : ( * starts ) [ i ] ;
int64_t end = ( * ends ) [ i ] < 0 ? ( ( * ends ) [ i ] + dim_value ) : ( * ends ) [ i ] ;
start = std : : max ( start , static_cast < int64_t > ( 0 ) ) ;
end = std : : min ( end , dim_value ) ;
int64_t step = ( * steps ) [ i ] ;
PADDLE_ENFORCE_NE (
step , 0 , platform : : errors : : InvalidArgument (
" Step should not be 0, but received step = %d. " , step ) ) ;
if ( step > 0 ) {
start = std : : min ( start , dim_value ) ;
end = std : : max ( end , static_cast < int64_t > ( 0 ) ) ;
PADDLE_ENFORCE_GT (
end , start ,
platform : : errors : : InvalidArgument (
" When step > 0, end should be greater than start, but "
" received end = %d, start = %d. " ,
end , start ) ) ;
} else {
// NOTE(liym27): When step < 0, start should less and equal to dim_value-1
// "end is -1" means contain the 0-th element of this axis.
start = std : : min ( start , dim_value - 1 ) ;
end = std : : max ( end , static_cast < int64_t > ( - 1 ) ) ;
PADDLE_ENFORCE_GT (
start , end ,
platform : : errors : : InvalidArgument (
" When step < 0, start should be greater than end, but "
" received start = %d, end = %d. " ,
start , end ) ) ;
}
( * starts ) [ i ] = start ;
( * ends ) [ i ] = end ;
}
}
/**
 * Computes the shape of the region selected by a strided slice.
 *
 * Starts from a copy of `in_dims` and, for each position i in `axes`,
 * replaces the extent of axis `axes[i]` with the number of indices visited
 * when walking from `starts[i]` towards `ends[i]` (exclusive) in increments
 * of `steps[i]`. Axes not listed in `axes` keep their extent from `in_dims`.
 *
 * The bounds are used exactly as given: no negative-index wrapping or
 * clamping happens here. The caller visible in this file runs
 * CheckAndUpdateSlice on starts/ends before calling this function.
 *
 * @param in_dims Dimensions of the tensor being sliced.
 * @param axes    Axes being sliced.
 * @param starts  Start index per entry of `axes`.
 * @param ends    Exclusive end index per entry of `axes`.
 * @param steps   Step per entry of `axes`; must not be 0.
 * @return A DDim equal to `in_dims` except on the sliced axes.
 *
 * Raises InvalidArgument (through PADDLE_ENFORCE_NE) when a step is 0, since
 * the extent is obtained by dividing by the step.
 */
inline framework::DDim GetSliceDims(const framework::DDim& in_dims,
                                    const std::vector<int64_t>& axes,
                                    const std::vector<int64_t>& starts,
                                    const std::vector<int64_t>& ends,
                                    const std::vector<int64_t>& steps) {
  framework::DDim slice_dims(in_dims);

  for (size_t i = 0; i < axes.size(); ++i) {
    const int64_t axis = axes[i];
    const int64_t start = starts[i];
    const int64_t end = ends[i];
    const int64_t step = steps[i];

    PADDLE_ENFORCE_NE(
        step, 0, platform::errors::InvalidArgument(
                     "Step should not be 0, but received step = %d.", step));

    if (step > 0) {
      // ceil((end - start) / step) using truncating integer division.
      slice_dims[axis] = (end - start + step - 1) / step;
    } else {
      // Same rounding for a descending walk: both numerator and step are
      // negative here, so the quotient is the positive element count.
      slice_dims[axis] = (end - start + step + 1) / step;
    }
  }
  return slice_dims;
}
@ -120,19 +165,36 @@ class SetValueKernel : public framework::OpKernel<T> {
template < size_t D >
void SetValueCompute ( const framework : : ExecutionContext & ctx ) const {
auto * in = ctx . Input < framework : : LoDTensor > ( " Input " ) ;
auto * value_tensor = ctx . Input < framework : : LoDTensor > ( " ValueTensor " ) ;
auto * out = ctx . Output < framework : : LoDTensor > ( " Out " ) ;
auto starts_tensor_list =
ctx . MultiInput < framework : : Tensor > ( " StartsTensorList " ) ;
auto ends_tensor_list = ctx . MultiInput < framework : : Tensor > ( " EndsTensorList " ) ;
auto steps_tensor_list =
ctx . MultiInput < framework : : Tensor > ( " StepsTensorList " ) ;
auto dtype =
static_cast < framework : : proto : : VarType : : Type > ( ctx . Attr < int > ( " dtype " ) ) ;
auto axes = ctx . Attr < std : : vector < int64_t > > ( " axes " ) ;
auto starts = ctx . Attr < std : : vector < int64_t > > ( " starts " ) ;
auto ends = ctx . Attr < std : : vector < int64_t > > ( " ends " ) ;
auto steps = ctx . Attr < std : : vector < int64_t > > ( " steps " ) ;
auto shape = ctx . Attr < std : : vector < int64_t > > ( " shape " ) ;
auto * value_tensor = ctx . Input < framework : : LoDTensor > ( " ValueTensor " ) ;
if ( ! starts_tensor_list . empty ( ) ) {
starts = GetDataFromTensorList < int64_t > ( starts_tensor_list ) ;
}
if ( ! ends_tensor_list . empty ( ) ) {
ends = GetDataFromTensorList < int64_t > ( ends_tensor_list ) ;
}
if ( ! steps_tensor_list . empty ( ) ) {
steps = GetDataFromTensorList < int64_t > ( steps_tensor_list ) ;
}
auto in_dims = in - > dims ( ) ;
auto value_dims = framework : : make_ddim ( shape ) ;
auto slice_dims = GetSliceDims ( in_dims , axes , starts , ends ) ;
CheckAndUpdateSlice ( in_dims , axes , & starts , & ends , & steps ) ;
auto slice_dims = GetSliceDims ( in_dims , axes , starts , ends , steps );
auto place = ctx . GetPlace ( ) ;
auto & eigen_place =
@ -160,46 +222,37 @@ class SetValueKernel : public framework::OpKernel<T> {
auto slice_e = framework : : EigenTensor < T , D > : : From ( slice_t , slice_dims ) ;
// Step 1: Set the value of out at `_index` to zero
// - Step 1.1 Get a slice tensor from out
Eigen : : array < int64_t , D > offsets , extents ;
Eigen : : array < std : : pair < int64_t , int64_t > , D > paddings ;
slice_e . device ( eigen_place ) = slice_e . constant ( T ( 0 ) ) ;
auto starts_indices = Eigen : : DSizes < Eigen : : DenseIndex , D > ( ) ;
auto ends_indices = Eigen : : DSizes < Eigen : : DenseIndex , D > ( ) ;
auto strides_indices = Eigen : : DSizes < Eigen : : DenseIndex , D > ( ) ;
for ( size_t i = 0 ; i < D ; + + i ) {
offsets [ i ] = 0 ;
extents [ i ] = slice_dims [ i ] ;
}
int64_t start ;
for ( size_t i = 0 ; i < axes . size ( ) ; + + i ) {
start = starts [ i ] < 0 ? ( starts [ i ] + in_dims [ axes [ i ] ] ) : starts [ i ] ;
start = std : : max ( start , static_cast < int64_t > ( 0 ) ) ;
offsets [ axes [ i ] ] = start ;
starts_indices [ i ] = 0 ;
ends_indices [ i ] = slice_dims [ i ] ;
strides_indices [ i ] = 1 ;
}
for ( size_t i = 0 ; i < paddings . size ( ) ; + + i ) {
paddings [ i ] . first = offsets [ i ] ;
paddings [ i ] . second = ( in_dims [ i ] - slice_dims [ i ] ) - offsets [ i ] ;
for ( size_t i = 0 ; i < axes . size ( ) ; i + + ) {
int axis_index = axes [ i ] ;
starts_indices [ axis_index ] = starts [ i ] ;
ends_indices [ axis_index ] = ends [ i ] ;
strides_indices [ axis_index ] = steps [ i ] ;
}
slice_e . device ( eigen_place ) = out_e . slice ( offsets , extents ) ;
// - Step 1.2 Get padded tensor by padding 0 to slice tensor
pad_e . device ( eigen_place ) = slice_e . pad ( paddings , T ( 0 ) ) ;
// - Step 1.3 Set 0 at `_index` of out tensor
out_e . device ( eigen_place ) = out_e - pad_e ;
out_e . stridedSlice ( starts_indices , ends_indices , strides_indices )
. device ( eigen_place ) = slice_e ;
// Step 2: Set a tensor with the same shape as out tensor. And its data at
// '_index' is the same as value_tensor, and data out of '_index' to zero
// - Step 2.1 Set the data of slice tensor to 0
slice_e . device ( eigen_place ) = slice_e . constant ( T ( 0 ) ) ;
// - Step 2.2 Set slice tensor with value
// - Step 2.1 Set slice tensor with value
if ( value_tensor ! = nullptr ) {
// ElementwiseComputeEx can do broadcasting
ElementwiseComputeEx < SubFunctor < T > , DeviceContext , T > (
ctx , & slice_t , value_tensor , - 1 , SubFunctor < T > ( ) , & slice_t ) ;
} else {
Tensor value_t ( dtype ) ;
auto value_dims = framework : : make_ddim ( shape ) ;
value_t . mutable_data < T > ( value_dims , place ) ;
auto value_name = GetValueName ( dtype ) ;
CopyVecotorToTensor < T > ( value_name . c_str ( ) , & value_t , ctx ) ;
@ -208,8 +261,10 @@ class SetValueKernel : public framework::OpKernel<T> {
ctx , & slice_t , & value_t , - 1 , SubFunctor < T > ( ) , & slice_t ) ;
}
// - Step 2.3 Pad slice tensor with 0
pad_e . device ( eigen_place ) = slice_e . pad ( paddings , T ( 0 ) ) ;
// - Step 2.2 Pad slice tensor with 0
pad_e . device ( eigen_place ) = pad_e . constant ( T ( 0 ) ) ;
pad_e . stridedSlice ( starts_indices , ends_indices , strides_indices )
. device ( eigen_place ) = slice_e ;
// Step 3: Set out tensor with value_tensor
out_e . device ( eigen_place ) = out_e - pad_e ;