@ -291,6 +291,34 @@ class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
}
} ;
template < typename T >
class SliceDoubleOpGradMaker : public framework : : SingleGradOpMaker < T > {
public :
using framework : : SingleGradOpMaker < T > : : SingleGradOpMaker ;
protected :
std : : unique_ptr < T > Apply ( ) const override {
auto * bind = new T ( ) ;
if ( this - > HasInput ( " StartsTensor " ) ) {
bind - > SetInput ( " StartsTensor " , this - > Input ( " StartsTensor " ) ) ;
}
if ( this - > HasInput ( " EndsTensor " ) ) {
bind - > SetInput ( " EndsTensor " , this - > Input ( " EndsTensor " ) ) ;
}
if ( this - > HasInput ( " StartsTensorList " ) ) {
bind - > SetInput ( " StartsTensorList " , this - > Input ( " StartsTensorList " ) ) ;
}
if ( this - > HasInput ( " EndsTensorList " ) ) {
bind - > SetInput ( " EndsTensorList " , this - > Input ( " EndsTensorList " ) ) ;
}
bind - > SetInput ( " Input " , this - > OutputGrad ( framework : : GradVarName ( " Input " ) ) ) ;
bind - > SetOutput ( " Out " , this - > InputGrad ( framework : : GradVarName ( " Out " ) ) ) ;
bind - > SetAttrMap ( this - > Attrs ( ) ) ;
bind - > SetType ( " slice " ) ;
return std : : unique_ptr < T > ( bind ) ;
}
} ;
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE ( SliceOpGradNoNeedBufferVarsInference ,
" Input " ) ;
@ -302,6 +330,8 @@ REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
ops : : SliceOpGradMaker < paddle : : framework : : OpDesc > ,
ops : : SliceOpGradMaker < paddle : : imperative : : OpBase > ) ;
REGISTER_OPERATOR ( slice_grad , ops : : SliceOpGrad ,
ops : : SliceDoubleOpGradMaker < paddle : : framework : : OpDesc > ,
ops : : SliceDoubleOpGradMaker < paddle : : imperative : : OpBase > ,
ops : : SliceOpGradNoNeedBufferVarsInference ) ;
REGISTER_OP_CPU_KERNEL (