@ -50,29 +50,56 @@ class ReshapeOp : public framework::OperatorWithKernel {
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) ,
PADDLE_ENFORCE _EQ ( ctx - > HasInput ( " X " ) , true ,
" Input(X) of ReshapeOp should not be null. " ) ;
" Input(X) of ReshapeOp should not be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) ,
PADDLE_ENFORCE _EQ ( ctx - > HasOutput ( " Out " ) , true ,
" Output(Out) of ReshapeOp should not be null. " ) ;
" Output(Out) of ReshapeOp should not be null. " ) ;
if ( ctx - > HasInputs ( " ShapeTensor " ) ) {
if ( ctx - > HasInputs ( " ShapeTensor " ) ) {
// top prority shape
// top prority shape
auto inputs_name = ctx - > Inputs ( " ShapeTensor " ) ;
auto ShapeTensor = ctx - > Inputs ( " ShapeTensor " ) ;
PADDLE_ENFORCE ( inputs_name . size ( ) > 0 , " shape tensor size can't be zero " ) ;
PADDLE_ENFORCE_GT ( ShapeTensor . size ( ) , 0 ,
auto out_dims = std : : vector < int > ( inputs_name . size ( ) , - 1 ) ;
" The size of Input(ShapeTensor) can't be zero " ) ;
ctx - > SetOutputDim ( " Out " , framework : : make_ddim ( out_dims ) ) ;
auto infer_shape = ctx - > Attrs ( ) . Get < std : : vector < int > > ( " shape " ) ;
const int64_t copy_dim_val = 0 ;
auto in_dims = ctx - > GetInputDim ( " X " ) ;
for ( size_t i = 0 ; i < infer_shape . size ( ) ; + + i ) {
if ( infer_shape [ i ] = = copy_dim_val ) {
PADDLE_ENFORCE_LT (
static_cast < int > ( i ) , in_dims . size ( ) ,
" The dimension of data to copy from input must be less "
" than the dimension of input. " ) ;
infer_shape [ i ] = in_dims [ i ] ;
}
}
auto infer_out_dims = framework : : make_ddim ( infer_shape ) ;
ctx - > SetOutputDim ( " Out " , infer_out_dims ) ;
return ;
}
const std : : vector < int > & shape = ctx - > Attrs ( ) . Get < std : : vector < int > > ( " shape " ) ;
if ( ctx - > HasInput ( " Shape " ) & & shape . empty ( ) ) {
auto shape_dims = ctx - > GetInputDim ( " Shape " ) ;
int num_ele = 1 ;
for ( int i = 0 ; i < shape_dims . size ( ) ; + + i ) {
num_ele * = shape_dims [ i ] ;
}
auto vec_dims = std : : vector < int > ( num_ele , - 1 ) ;
auto out_dims = framework : : make_ddim ( vec_dims ) ;
ctx - > SetOutputDim ( " Out " , out_dims ) ;
ctx - > ShareLoD ( " X " , /*->*/ " Out " ) ;
return ;
return ;
}
}
if ( ctx - > HasInput ( " Shape " ) & & ctx - > IsRuntime ( ) ) {
if ( ctx - > HasInput ( " Shape " ) & & ! shape . empty ( ) & & ctx - > IsRuntime ( ) ) {
// If true, set the shape of Output(Out) according to Input(Shape) in
// If true, set the shape of Output(Out) according to Input(Shape) in
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
// ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
ctx - > ShareLoD ( " X " , /*->*/ " Out " ) ;
ctx - > ShareLoD ( " X " , /*->*/ " Out " ) ;
return ;
return ;
}
}
const std : : vector < int > & shape = ctx - > Attrs ( ) . Get < std : : vector < int > > ( " shape " ) ;
PADDLE_ENFORCE ( ! shape . empty ( ) ,
PADDLE_ENFORCE _EQ ( ! shape . empty ( ) , true ,
" The shape information must be set by Attr(shape). " ) ;
" The shape information must be set by Attr(shape). " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
auto out_dims = ValidateShape ( shape , x_dims ) ;
auto out_dims = ValidateShape ( shape , x_dims ) ;
ctx - > SetOutputDim ( " Out " , out_dims ) ;
ctx - > SetOutputDim ( " Out " , out_dims ) ;
@ -99,18 +126,18 @@ class ReshapeOp : public framework::OperatorWithKernel {
int unk_dim_idx = - 1 ;
int unk_dim_idx = - 1 ;
for ( size_t i = 0 ; i < shape . size ( ) ; + + i ) {
for ( size_t i = 0 ; i < shape . size ( ) ; + + i ) {
if ( shape [ i ] = = unk_dim_val ) {
if ( shape [ i ] = = unk_dim_val ) {
PADDLE_ENFORCE (
PADDLE_ENFORCE _EQ (
unk_dim_idx = = - 1 ,
unk_dim_idx , - 1 ,
" Only one input dimension of Attr(shape) can be unknown. " ) ;
" Only one input dimension of Attr(shape) can be unknown. " ) ;
unk_dim_idx = i ;
unk_dim_idx = i ;
} else if ( shape [ i ] = = copy_dim_val ) {
} else if ( shape [ i ] = = copy_dim_val ) {
PADDLE_ENFORCE (
PADDLE_ENFORCE _LT (
static_cast < int > ( i ) < in_dims . size ( ) ,
static_cast < int > ( i ) , in_dims . size ( ) ,
" The index of dimension to copy from input shape must be less "
" The index of dimension to copy from input shape must be less "
" than the size of input shape. " ) ;
" than the size of input shape. " ) ;
} else {
} else {
PADDLE_ENFORCE (
PADDLE_ENFORCE _GT (
shape [ i ] > 0 ,
shape [ i ] , 0 ,
" Each input dimension of Attr(shape) must not be negtive except "
" Each input dimension of Attr(shape) must not be negtive except "
" one unknown dimension. " ) ;
" one unknown dimension. " ) ;
}
}
@ -231,9 +258,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) shouldn't be null. " ) ;
PADDLE_ENFORCE _EQ ( ctx - > HasInput ( " X " ) , true , " Input(X) shouldn't be null. " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
PADDLE_ENFORCE _EQ ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) , true ,
" Input(Out@GRAD) shouldn't be null. " ) ;
" Input(Out@GRAD) shouldn't be null. " ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , ctx - > GetInputDim ( " X " ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , ctx - > GetInputDim ( " X " ) ) ;
}
}
@ -314,8 +341,8 @@ class Reshape2Op : public ReshapeOp {
: ReshapeOp ( type , inputs , outputs , attrs ) { }
: ReshapeOp ( type , inputs , outputs , attrs ) { }
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasOutput ( " XShape " ) ,
PADDLE_ENFORCE _EQ ( ctx - > HasOutput ( " XShape " ) , true ,
" Output(XShape) of ReshapeOp should not be null. " ) ;
" Output(XShape) of ReshapeOp should not be null. " ) ;
const auto & x_dims = ctx - > GetInputDim ( " X " ) ;
const auto & x_dims = ctx - > GetInputDim ( " X " ) ;
std : : vector < int64_t > xshape_dims ( x_dims . size ( ) + 1 ) ;
std : : vector < int64_t > xshape_dims ( x_dims . size ( ) + 1 ) ;
xshape_dims [ 0 ] = 0 ;
xshape_dims [ 0 ] = 0 ;
@ -365,9 +392,10 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
: OperatorWithKernel ( type , inputs , outputs , attrs ) { }
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " XShape " ) , " Input(XShape) shouldn't be null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " XShape " ) , true ,
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
" Input(XShape) shouldn't be null. " ) ;
" Input(Out@GRAD) shouldn't be null. " ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) , true ,
" Input(Out@GRAD) shouldn't be null. " ) ;
auto xshape_dims = ctx - > GetInputDim ( " XShape " ) ;
auto xshape_dims = ctx - > GetInputDim ( " XShape " ) ;
auto x_dims = framework : : slice_ddim ( xshape_dims , 1 , xshape_dims . size ( ) ) ;
auto x_dims = framework : : slice_ddim ( xshape_dims , 1 , xshape_dims . size ( ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , x_dims ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , x_dims ) ;