@ -28,14 +28,15 @@ class ExpandOp : public framework::OperatorWithKernel {
protected :
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) , " Output(Out) should not be null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " X " ) , true , " Input(X) should not be null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( " Out " ) , true ,
" Output(Out) should not be null. " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
std : : vector < int > expand_times ( x_dims . size ( ) , - 1 ) ;
auto expand_times = ctx - > Attrs ( ) . Get < std : : vector < int > > ( " expand_times " ) ;
if ( ! ctx - > HasInputs ( " expand_times_tensor " ) ) {
expand_times = ctx- > Attrs ( ) . Get < std: : vector < int > > ( " expand_times " ) ;
if ( expand_times . size ( ) = = 0 ) {
expand_times = std: : vector < int > ( x_dims . size ( ) , - 1 ) ;
}
PADDLE_ENFORCE_EQ ( static_cast < size_t > ( x_dims . size ( ) ) , expand_times . size ( ) ,
@ -49,6 +50,9 @@ class ExpandOp : public framework::OperatorWithKernel {
if ( x_dims [ i ] = = - 1 | | expand_times [ i ] = = - 1 ) {
out_shape [ i ] = - 1 ;
} else {
PADDLE_ENFORCE_GT (
expand_times [ i ] , 0 ,
" The element of Attr(expand_times) must greater than 0. " ) ;
out_shape [ i ] = x_dims [ i ] * expand_times [ i ] ;
}
}
@ -69,7 +73,7 @@ class ExpandOp : public framework::OperatorWithKernel {
framework : : OpKernelType GetKernelTypeForVar (
const std : : string & var_name , const Tensor & tensor ,
const framework : : OpKernelType & expected_kernel_type ) const override {
if ( var_name = = " expand_times_tensor " ) {
if ( var_name = = " expand_times_tensor " | | var_name = = " ExpandTimes " ) {
return expected_kernel_type ;
}
return framework : : OpKernelType ( expected_kernel_type . data_type_ ,
@ -83,7 +87,15 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput ( " X " ,
" (Tensor, default Tensor<float>). A tensor with rank in [1, 6]. "
" X is the input to be expanded. " ) ;
AddInput ( " expand_times_tensor " , " (Tensor Tensor<int>), epxand times for X " )
AddInput ( " ExpandTimes " ,
" (Tensor<int>), optional). If provided, expand according to "
" this given expand times. It has a higher priority than "
" expand_times_tensor and expand_times. " )
. AsDispensable ( ) ;
AddInput ( " expand_times_tensor " ,
" (Tensor Tensor<int>), epxand times for X. "
" It has a higher priority than expand_times, but a lower priority "
" than ExpandTimes " )
. AsDuplicable ( )
. AsDispensable ( ) ;
AddOutput ( " Out " ,
@ -127,9 +139,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected :
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
" Input(Out@GRAD) should not be null. " ) ;
PADDLE_ENFORCE _EQ ( ctx - > HasInput ( " X " ) , true , " Input(X) should not be null. " ) ;
PADDLE_ENFORCE _EQ ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) , true ,
" Input(Out@GRAD) should not be null. " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
std : : vector < int > expand_times =
@ -147,12 +159,15 @@ class ExpandGradOp : public framework::OperatorWithKernel {
}
for ( size_t i = start_pos ; i < expand_times . size ( ) ; + + i ) {
PADDLE_ENFORCE_EQ ( x_dims [ i ] * expand_times [ i ] , out_dims [ i ] ,
" Each dimension size of Input(Out@GRAD) should be "
" equal to multiplication of crroresponding dimension "
" size of Input(X) and Attr(expand_times) value. " ) ;
if ( expand_times [ i ] = = - 1 ) {
continue ;
} else {
PADDLE_ENFORCE_EQ ( x_dims [ i ] * expand_times [ i ] , out_dims [ i ] ,
" Each dimension size of Input(Out@GRAD) should be "
" equal to multiplication of crroresponding dimension "
" size of Input(X) and Attr(expand_times) value. " ) ;
}
}
auto x_grad_name = framework : : GradVarName ( " X " ) ;
if ( ctx - > HasOutput ( x_grad_name ) ) {
@ -191,6 +206,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
op - > SetInput ( framework : : GradVarName ( " Out " ) , OutputGrad ( " Out " ) ) ;
op - > SetOutput ( framework : : GradVarName ( " X " ) , InputGrad ( " X " ) ) ;
op - > SetInput ( " expand_times_tensor " , Input ( " expand_times_tensor " ) ) ;
op - > SetInput ( " ExpandTimes " , Input ( " ExpandTimes " ) ) ;
op - > SetAttrMap ( Attrs ( ) ) ;
return op ;
}