@ -17,6 +17,7 @@ limitations under the License. */
# include <utility>
# include <vector>
# include "paddle/fluid/framework/op_registry.h"
# include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
@ -58,7 +59,12 @@ template <typename DeviceContext, typename T>
class SliceKernel : public framework : : OpKernel < T > {
public :
void Compute ( const framework : : ExecutionContext & ctx ) const override {
int rank = ctx . Input < framework : : Tensor > ( " Input " ) - > dims ( ) . size ( ) ;
const framework : : Variable * input_var = ctx . InputVar ( " Input " ) ;
bool is_tensor_array = input_var - > IsType < framework : : LoDTensorArray > ( ) ;
int rank = is_tensor_array
? 1
: ctx . Input < framework : : Tensor > ( " Input " ) - > dims ( ) . size ( ) ;
switch ( rank ) {
case 1 :
SliceCompute < 1 > ( ctx ) ;
@ -86,17 +92,17 @@ class SliceKernel : public framework::OpKernel<T> {
void SliceCompute ( const framework : : ExecutionContext & context ) const {
auto & place =
* context . template device_context < DeviceContext > ( ) . eigen_device ( ) ;
auto in = context . Input < framework : : Tensor > ( " Input " ) ;
auto out = context . Output < framework : : Tensor > ( " Out " ) ;
auto out_dims = out - > dims ( ) ;
auto in_dims = in - > dims ( ) ;
const framework : : Variable * input_var = context . InputVar ( " Input " ) ;
framework : : Variable * out_var = context . OutputVar ( " Out " ) ;
bool input_is_tensor_array = input_var - > IsType < framework : : LoDTensorArray > ( ) ;
bool out_is_tensor_array = out_var - > IsType < framework : : LoDTensorArray > ( ) ;
auto axes = context . Attr < std : : vector < int > > ( " axes " ) ;
auto starts = context . Attr < std : : vector < int > > ( " starts " ) ;
auto ends = context . Attr < std : : vector < int > > ( " ends " ) ;
auto decrease_axis = context . Attr < std : : vector < int > > ( " decrease_axis " ) ;
auto infer_flags = context . Attr < std : : vector < int > > ( " infer_flags " ) ;
auto list_new_ends_tensor =
context . MultiInput < framework : : Tensor > ( " EndsTensorList " ) ;
auto list_new_starts_tensor =
@ -109,7 +115,6 @@ class SliceKernel : public framework::OpKernel<T> {
if ( list_new_starts_tensor . size ( ) > 0 | | list_new_ends_tensor . size ( ) > 0 ) {
need_infer = true ;
}
if ( need_infer ) {
if ( context . HasInput ( " StartsTensor " ) ) {
auto * starts_tensor = context . Input < framework : : Tensor > ( " StartsTensor " ) ;
@ -117,17 +122,70 @@ class SliceKernel : public framework::OpKernel<T> {
} else if ( list_new_starts_tensor . size ( ) > 0 ) {
starts = get_new_data_from_tensorlist ( list_new_starts_tensor ) ;
}
PADDLE_ENFORCE_EQ (
starts . size ( ) , axes . size ( ) ,
" The size of starts must be equal to the size of axes. " ) ;
if ( context . HasInput ( " EndsTensor " ) ) {
auto * ends_tensor = context . Input < framework : : Tensor > ( " EndsTensor " ) ;
ends = get_new_data_from_tensor ( ends_tensor ) ;
} else if ( list_new_ends_tensor . size ( ) > 0 ) {
ends = get_new_data_from_tensorlist ( list_new_ends_tensor ) ;
}
PADDLE_ENFORCE_EQ ( ends . size ( ) , axes . size ( ) ,
" The size of ends must be equal to the size of axes. " ) ;
}
PADDLE_ENFORCE_EQ (
starts . size ( ) , axes . size ( ) ,
platform : : errors : : InvalidArgument (
" The size of starts must be equal to the size of axes. " ) ) ;
PADDLE_ENFORCE_EQ (
ends . size ( ) , axes . size ( ) ,
platform : : errors : : InvalidArgument (
" The size of ends must be equal to the size of axes. " ) ) ;
// LoDTensorArray input path. The array is treated as a rank-1 container, so
// only starts[0] / ends[0] are consulted. This branch always returns, so the
// dense-tensor code that follows it is not reached for array inputs.
if ( input_is_tensor_array ) {
auto in_array = context . Input < framework : : LoDTensorArray > ( " Input " ) ;
// If the input is LoDTensorArray, the rank of input is 1.
int in_size = in_array - > size ( ) ;
// A negative index is offset by the array length (counts from the end).
int start = starts [ 0 ] < 0 ? ( starts [ 0 ] + in_size ) : starts [ 0 ] ;
int end = ends [ 0 ] < 0 ? ( ends [ 0 ] + in_size ) : ends [ 0 ] ;
// Clamp start to >= 0 and end to [0, in_size]. start has no explicit upper
// clamp; the end > start check below implies start < in_size.
start = std : : max ( start , 0 ) ;
end = std : : max ( end , 0 ) ;
end = std : : min ( end , in_size ) ;
// An empty or reversed range (end <= start after clamping) is rejected.
PADDLE_ENFORCE_GT ( end , start ,
platform : : errors : : InvalidArgument (
" Attr(ends) should be greater than attr(starts) in "
" slice op. But received ends = %d, starts = %d. " ,
end , start ) ) ;
int out_size = end - start ;
if ( out_is_tensor_array ) {
// Array -> array: output element i receives input element (start + i).
auto out_array = context . Output < framework : : LoDTensorArray > ( " Out " ) ;
out_array - > resize ( out_size ) ;
for ( int i = 0 ; i < out_size ; + + i ) {
auto * out_tensor = & out_array - > at ( i ) ;
// NOTE(review): plain `auto` makes in_tensor a by-value copy of the array
// element (presumably at() returns a reference - confirm); `const auto &`
// would avoid the copy.
auto in_tensor = in_array - > at ( i + start ) ;
// The LoD is propagated even when the element holds no data.
out_tensor - > set_lod ( in_tensor . lod ( ) ) ;
if ( in_tensor . memory_size ( ) > 0 ) {
TensorCopy ( in_tensor , context . GetPlace ( ) , out_tensor ) ;
} else {
// Elements without allocated memory are skipped; only a verbose log is
// emitted. NOTE(review): the message names 'x_tensor', but no variable of
// that name is visible in this branch (the tensor is `in_tensor`).
VLOG ( 10 )
< < " WARNING: The input tensor 'x_tensor' holds no memory, so "
" nothing has been written to output array[ "
< < i < < " ]. " ;
}
}
} else {
// Array -> single Tensor: only the element at `start` is copied. out_size is
// not consulted in this branch.
auto out = context . Output < framework : : Tensor > ( " Out " ) ;
auto in_tensor = in_array - > at ( start ) ;
TensorCopy ( in_tensor , context . GetPlace ( ) , out ) ;
}
return ;
}
auto in = context . Input < framework : : Tensor > ( " Input " ) ;
auto out = context . Output < framework : : Tensor > ( " Out " ) ;
auto out_dims = out - > dims ( ) ;
auto in_dims = in - > dims ( ) ;
if ( need_infer ) {
out_dims = in_dims ;
int dim_value , start , end ;
for ( size_t i = 0 ; i < axes . size ( ) ; + + i ) {
@ -233,7 +291,12 @@ template <typename DeviceContext, typename T>
class SliceGradKernel : public framework : : OpKernel < T > {
public :
void Compute ( const framework : : ExecutionContext & ctx ) const override {
size_t rank = ctx . Input < framework : : Tensor > ( " Input " ) - > dims ( ) . size ( ) ;
const framework : : Variable * input_var = ctx . InputVar ( " Input " ) ;
bool is_tensor_array = input_var - > IsType < framework : : LoDTensorArray > ( ) ;
size_t rank = is_tensor_array
? 1
: ctx . Input < framework : : Tensor > ( " Input " ) - > dims ( ) . size ( ) ;
switch ( rank ) {
case 1 :
SliceCompute < 1 > ( ctx ) ;
@ -261,17 +324,9 @@ class SliceGradKernel : public framework::OpKernel<T> {
void SliceCompute ( const framework : : ExecutionContext & context ) const {
auto & place =
* context . template device_context < DeviceContext > ( ) . eigen_device ( ) ;
auto * d_out =
context . Input < framework : : Tensor > ( framework : : GradVarName ( " Out " ) ) ;
auto * d_input =
context . Output < framework : : Tensor > ( framework : : GradVarName ( " Input " ) ) ;
d_input - > mutable_data < T > ( context . GetPlace ( ) ) ;
auto out_dims = d_out - > dims ( ) ;
auto in_dims = d_input - > dims ( ) ;
auto axes = context . Attr < std : : vector < int > > ( " axes " ) ;
auto starts = context . Attr < std : : vector < int > > ( " starts " ) ;
auto ends = context . Attr < std : : vector < int > > ( " ends " ) ;
auto list_new_ends_tensor =
context . MultiInput < framework : : Tensor > ( " EndsTensorList " ) ;
auto list_new_starts_tensor =
@ -290,6 +345,66 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto * ends_tensor = context . Input < framework : : Tensor > ( " EndsTensor " ) ;
ends = get_new_data_from_tensor ( ends_tensor ) ;
}
framework : : Variable * d_input_var =
context . OutputVar ( framework : : GradVarName ( " Input " ) ) ;
const framework : : Variable * d_out_var =
context . InputVar ( framework : : GradVarName ( " Out " ) ) ;
bool d_input_is_tensor_array =
d_input_var - > IsType < framework : : LoDTensorArray > ( ) ;
bool d_out_is_tensor_array = d_out_var - > IsType < framework : : LoDTensorArray > ( ) ;
// Gradient path taken when the gradient of Input is a LoDTensorArray. Every
// element of the input gradient is zero-filled first, then the incoming
// output gradient is copied in starting at index `start`. This branch always
// returns, so the dense-tensor gradient code that follows is not reached.
if ( d_input_is_tensor_array ) {
auto * input_array = context . Input < framework : : LoDTensorArray > ( " Input " ) ;
auto * d_input_array = context . Output < framework : : LoDTensorArray > (
framework : : GradVarName ( " Input " ) ) ;
// The input gradient gets the same number of elements as the forward input.
int d_in_size = input_array - > size ( ) ;
d_input_array - > resize ( d_in_size ) ;
// If the input is LoDTensorArray, the rank of input is 1.
// So only use the 0th element of starts.
// NOTE(review): `ends` is not consulted in this branch; the number of
// elements written back is taken from the output gradient instead.
int start = starts [ 0 ] < 0 ? ( starts [ 0 ] + d_in_size ) : starts [ 0 ] ;
start = std : : max ( start , 0 ) ;
// set zero
platform : : DeviceContextPool & pool =
platform : : DeviceContextPool : : Instance ( ) ;
auto & dev_ctx = * pool . Get ( context . GetPlace ( ) ) ;
T value = 0.0 ;
math : : SetConstant < DeviceContext , T > functor ;
// Give each gradient element the dims of the matching forward input
// element, allocate it on the current place, and fill it with zeros.
for ( int i = 0 ; i < d_in_size ; + + i ) {
auto dim = input_array - > at ( i ) . dims ( ) ;
d_input_array - > at ( i ) . Resize ( dim ) ;
d_input_array - > at ( i ) . mutable_data < T > ( context . GetPlace ( ) ) ;
// NOTE(review): the pooled context is reinterpret_cast to DeviceContext;
// presumably its dynamic type matches the kernel's DeviceContext - confirm.
// The static_cast<T> is redundant because `value` is already a T.
functor ( reinterpret_cast < const DeviceContext & > ( dev_ctx ) ,
& d_input_array - > at ( i ) , static_cast < T > ( value ) ) ;
}
if ( d_out_is_tensor_array ) {
// Array gradient: element i of the output gradient overwrites element
// (start + i) of the input gradient.
// NOTE(review): start + i is bounded only by at() here - confirm that the
// output gradient never has more than d_in_size - start elements.
auto * d_out_array = context . Input < framework : : LoDTensorArray > (
framework : : GradVarName ( " Out " ) ) ;
int d_out_size = d_out_array - > size ( ) ;
for ( int i = 0 ; i < d_out_size ; + + i ) {
TensorCopy ( d_out_array - > at ( i ) , context . GetPlace ( ) ,
& ( d_input_array - > at ( start + i ) ) ) ;
}
} else {
// Single-tensor gradient: it overwrites only the element at `start`.
auto * d_out =
context . Input < framework : : Tensor > ( framework : : GradVarName ( " Out " ) ) ;
TensorCopy ( * d_out , context . GetPlace ( ) , & ( d_input_array - > at ( start ) ) ) ;
}
return ;
}
auto * d_out =
context . Input < framework : : Tensor > ( framework : : GradVarName ( " Out " ) ) ;
auto * d_input =
context . Output < framework : : Tensor > ( framework : : GradVarName ( " Input " ) ) ;
d_input - > mutable_data < T > ( context . GetPlace ( ) ) ;
auto out_dims = d_out - > dims ( ) ;
auto in_dims = d_input - > dims ( ) ;
auto decrease_axis = context . Attr < std : : vector < int > > ( " decrease_axis " ) ;
if ( decrease_axis . size ( ) > 0 ) {