@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License . */
limitations under the License . */
# include "paddle/fluid/operators/transpose_op.h"
# include "paddle/fluid/operators/transpose_op.h"
# include <string>
# include <vector>
# include <vector>
namespace paddle {
namespace paddle {
@ -24,7 +25,7 @@ class TransposeOp : public framework::OperatorWithKernel {
public :
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) , " Output(Out) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " Out " ) , " Output(Out) should not be null " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
auto x_dims = ctx - > GetInputDim ( " X " ) ;
@ -90,7 +91,7 @@ The behavior of this operator is similar to how `numpy.transpose` works.
2 & 5
2 & 5
\ end { pmatrix } $ $
\ end { pmatrix } $ $
- Given a input tensor with shape $ ( N , C , H , W ) $ and the ` axes ` is
- Given a input tensor with shape $ ( N , C , H , W ) $ and the ` axes ` is
$ [ 0 , 2 , 3 , 1 ] $ , then shape of the output tensor will be : $ ( N , H , W , C ) $ .
$ [ 0 , 2 , 3 , 1 ] $ , then shape of the output tensor will be : $ ( N , H , W , C ) $ .
) DOC " );
) DOC " );
@ -101,7 +102,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
public :
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( " X " ) , " Input(X) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
" Input(Out@GRAD) should not be null " ) ;
" Input(Out@GRAD) should not be null " ) ;
@ -113,6 +114,93 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
}
}
} ;
} ;
// FIXME(zcd): transpose2 extends transpose with an intermediate output,
// XShape, which carries the shape and LoD of X so that they can be used by
// transpose2_grad. Because the gradient op then no longer needs X itself,
// the framework can reuse the memory of X as soon as transpose2_op has
// finished.
// Considering compatibility issues, we could not fix transpose2_op.
class Transpose2Op : public TransposeOp {
public :
Transpose2Op ( const std : : string & type ,
const framework : : VariableNameMap & inputs ,
const framework : : VariableNameMap & outputs ,
const framework : : AttributeMap & attrs )
: TransposeOp ( type , inputs , outputs , attrs ) { }
void InferShape ( framework : : InferShapeContext * ctx ) const override {
TransposeOp : : InferShape ( ctx ) ;
PADDLE_ENFORCE ( ctx - > HasOutput ( " XShape " ) ,
" Output(XShape) should not be null " ) ;
const auto & in_dims = ctx - > GetInputDim ( " X " ) ;
std : : vector < int64_t > x_shape_dim ( in_dims . size ( ) + 1 ) ;
x_shape_dim [ 0 ] = 0 ;
for ( int i = 0 ; i < in_dims . size ( ) ; + + i ) {
x_shape_dim [ i + 1 ] = in_dims [ i ] ;
}
ctx - > SetOutputDim ( " XShape " , framework : : make_ddim ( x_shape_dim ) ) ;
ctx - > ShareLoD ( " X " , /*->*/ " XShape " ) ;
}
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
return framework : : OpKernelType (
framework : : ToDataType ( ctx . Input < framework : : LoDTensor > ( " X " ) - > type ( ) ) ,
ctx . device_context ( ) ) ;
}
} ;
class Transpose2OpMaker : public TransposeOpMaker {
public :
void Make ( ) override {
TransposeOpMaker : : Make ( ) ;
AddOutput ( " XShape " , " (Tensor)The output tensor. " ) . AsIntermediate ( ) ;
}
} ;
class Transpose2GradMaker : public framework : : SingleGradOpDescMaker {
public :
using framework : : SingleGradOpDescMaker : : SingleGradOpDescMaker ;
std : : unique_ptr < framework : : OpDesc > Apply ( ) const override {
auto * grad_op = new framework : : OpDesc ( ) ;
grad_op - > SetType ( " transpose2_grad " ) ;
grad_op - > SetInput ( " XShape " , Output ( " XShape " ) ) ;
grad_op - > SetInput ( framework : : GradVarName ( " Out " ) , OutputGrad ( " Out " ) ) ;
grad_op - > SetOutput ( framework : : GradVarName ( " X " ) , InputGrad ( " X " ) ) ;
grad_op - > SetAttrMap ( Attrs ( ) ) ;
return std : : unique_ptr < framework : : OpDesc > ( grad_op ) ;
}
} ;
class Transpose2OpGrad : public framework : : OperatorWithKernel {
public :
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE ( ctx - > HasInput ( " XShape " ) , " Input(XShape) should not be null " ) ;
PADDLE_ENFORCE ( ctx - > HasInput ( framework : : GradVarName ( " Out " ) ) ,
" Input(Out@GRAD) should not be null " ) ;
if ( ctx - > HasOutput ( framework : : GradVarName ( " X " ) ) ) {
auto xshape_dim = ctx - > GetInputDim ( " XShape " ) ;
auto x_shape_dim =
framework : : slice_ddim ( xshape_dim , 1 , xshape_dim . size ( ) ) ;
ctx - > SetOutputDim ( framework : : GradVarName ( " X " ) , x_shape_dim ) ;
ctx - > ShareLoD ( " XShape " , framework : : GradVarName ( " X " ) ) ;
}
}
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
return framework : : OpKernelType (
framework : : ToDataType (
ctx . Input < framework : : LoDTensor > ( framework : : GradVarName ( " Out " ) )
- > type ( ) ) ,
ctx . device_context ( ) ) ;
}
} ;
} // namespace operators
} // namespace operators
} // namespace paddle
} // namespace paddle
@ -120,8 +208,20 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR ( transpose , ops : : TransposeOp , ops : : TransposeOpMaker ,
REGISTER_OPERATOR ( transpose , ops : : TransposeOp , ops : : TransposeOpMaker ,
paddle : : framework : : DefaultGradOpDescMaker < true > ) ;
paddle : : framework : : DefaultGradOpDescMaker < true > ) ;
REGISTER_OPERATOR ( transpose_grad , ops : : TransposeOpGrad ) ;
REGISTER_OPERATOR ( transpose_grad , ops : : TransposeOpGrad ) ;
REGISTER_OP_CPU_KERNEL (
REGISTER_OP_CPU_KERNEL (
transpose , ops : : TransposeKernel < paddle : : platform : : CPUDeviceContext , float > ) ;
transpose , ops : : TransposeKernel < paddle : : platform : : CPUDeviceContext , float > ) ;
REGISTER_OP_CPU_KERNEL (
REGISTER_OP_CPU_KERNEL (
transpose_grad ,
transpose_grad ,
ops : : TransposeGradKernel < paddle : : platform : : CPUDeviceContext , float > ) ;
ops : : TransposeGradKernel < paddle : : platform : : CPUDeviceContext , float > ) ;
// Register the transpose2 operator together with its op maker and the grad
// maker that describes its gradient op.
REGISTER_OPERATOR ( transpose2 , ops : : Transpose2Op , ops : : Transpose2OpMaker ,
ops : : Transpose2GradMaker ) ;
// Register the gradient operator named by Transpose2GradMaker.
REGISTER_OPERATOR ( transpose2_grad , ops : : Transpose2OpGrad ) ;
// CPU kernel for transpose2: the TransposeKernel template, instantiated for
// float only.
REGISTER_OP_CPU_KERNEL (
transpose2 ,
ops : : TransposeKernel < paddle : : platform : : CPUDeviceContext , float > ) ;
// CPU kernel for transpose2_grad: the TransposeGradKernel template,
// instantiated for float only.
REGISTER_OP_CPU_KERNEL (
transpose2_grad ,
ops : : TransposeGradKernel < paddle : : platform : : CPUDeviceContext , float > ) ;