From e2e82bde32709a0bedaf940c60c3d5e3b73d22b1 Mon Sep 17 00:00:00 2001
From: minqiyang <minqiyang@baidu.com>
Date: Thu, 11 Oct 2018 21:12:56 +0800
Subject: [PATCH] Accelerate Reshape op

---
 paddle/fluid/operators/reshape_op.cc         | 82 ++++++++++++--------
 paddle/fluid/operators/sequence_concat_op.cc |  5 +-
 2 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc
index d72f85f2c4..b8fdc3f826 100644
--- a/paddle/fluid/operators/reshape_op.cc
+++ b/paddle/fluid/operators/reshape_op.cc
@@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
 [2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
 
 3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
-Attr(shape) still should be set correctly to gurantee shape inference in 
+Attr(shape) still should be set correctly to guarantee shape inference in
 compile-time.
 
 )DOC");
@@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
 class ReshapeKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -227,12 +228,15 @@
           "sequence_reshape op.");
     }
 
-    out->mutable_data(ctx.GetPlace(), in->type());
-    framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    if (in->data<T>() !=
+        reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
+      framework::TensorCopySync(*in, ctx.GetPlace(), out);
+    }
     out->Resize(out_dims);
   }
 };
 
+template <typename T>
 class ReshapeGradKernel {
  public:
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -240,8 +244,9 @@
     auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
 
     auto in_dims = d_x->dims();
-    d_x->mutable_data(ctx.GetPlace(), d_out->type());
-    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
+      framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    }
     d_x->Resize(in_dims);
   }
 };
@@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
       : ReshapeOp(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    ReshapeOp::InferShape(ctx);
     PADDLE_ENFORCE(ctx->HasOutput("XShape"),
                    "Output(XShape) of ReshapeOp should not be null.");
     const auto &x_dims = ctx->GetInputDim("X");
@@ -270,6 +274,8 @@
     }
     ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
     ctx->ShareLoD("X", /*->*/ "XShape");
+
+    ReshapeOp::InferShape(ctx);
   }
 };
 
@@ -335,38 +341,46 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
                   ops::Reshape2GradMaker);
 REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int, ops::ReshapeKernel,
-                               int64_t, ops::ReshapeKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                               double, ops::ReshapeKernel<double>, int,
+                               ops::ReshapeKernel<int>, int64_t,
+                               ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
+                               ops::ReshapeGradKernel<float>, double,
+                               ops::ReshapeGradKernel<double>, int,
+                               ops::ReshapeGradKernel<int>, int64_t,
+                               ops::ReshapeGradKernel<int64_t>);
 
 #ifdef PADDLE_WITH_CUDA
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                                ops::ReshapeKernel, int, ops::ReshapeKernel,
-                                int64_t, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
+                                double, ops::ReshapeKernel<double>, int,
+                                ops::ReshapeKernel<int>, int64_t,
+                                ops::ReshapeKernel<int64_t>);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
+                                ops::ReshapeGradKernel<float>, double,
+                                ops::ReshapeGradKernel<double>, int,
+                                ops::ReshapeGradKernel<int>, int64_t,
+                                ops::ReshapeGradKernel<int64_t>);
 #endif
diff --git a/paddle/fluid/operators/sequence_concat_op.cc b/paddle/fluid/operators/sequence_concat_op.cc
index 397a318295..12b53be708 100644
--- a/paddle/fluid/operators/sequence_concat_op.cc
+++ b/paddle/fluid/operators/sequence_concat_op.cc
@@ -90,11 +90,12 @@ REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
                   paddle::framework::DefaultGradOpDescMaker<false>);
 template <typename T>
 using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
-REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>);
+REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
+                       Kernel<int64_t>);
 REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel,
                   op::SeqConcatGradShapeInferer);
 template <typename T>
 using GradKernel =
     op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
 REGISTER_OP_CPU_KERNEL(sequence_concat_grad, GradKernel<float>,
-                       GradKernel<double>);
+                       GradKernel<double>, GradKernel<int64_t>);
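
Editor's note: the heart of the speedup is the pointer check before TensorCopySync. As I read it, mutable_data returns the already-existing allocation when the output shares the input's holder (the usual in-place reshape case), so the copy is skipped and reshape degenerates to a metadata-only Resize; the grad kernel applies the same pattern, which is why both kernels become templates on T. Below is a minimal standalone sketch of that idea; the Tensor struct, copy_sync, and reshape names are hypothetical illustrations, not Paddle's API.

#include <utility>
#include <vector>

// Hypothetical minimal tensor: a flat buffer plus a shape (dims).
struct Tensor {
  std::vector<float> buf;
  std::vector<int> dims;
  float *data() { return buf.data(); }
  const float *data() const { return buf.data(); }
};

// Stand-in for framework::TensorCopySync: a full element-wise copy.
void copy_sync(const Tensor &src, Tensor *dst) { dst->buf = src.buf; }

// Reshape in the spirit of the patch: copy only when dst does not
// already alias src's storage; otherwise only rewrite the dims.
void reshape(const Tensor &in, Tensor *out, std::vector<int> new_dims) {
  if (in.data() != out->data()) {  // distinct buffers, so a copy is needed
    copy_sync(in, out);
  }
  out->dims = std::move(new_dims);  // metadata-only "Resize"
}

In the in-place case, e.g. reshape(t, &t, {2, 3}), no element is moved at all, which is what removes the copy from the hot path. The remaining hunks support the same goal: Reshape2Op::InferShape now sets up XShape before delegating to ReshapeOp::InferShape, and sequence_concat gains int64_t kernels alongside float and double.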