diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index d89160e18b..f6848a800f 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -782,6 +782,8 @@ class SquareDoubleGradMaker
 DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
                            {framework::GradVarName("Out"),
                             framework::GradVarName("X")});
+DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInference,
+                           {"DDX", "DDOut"});
 
 class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
  public:
@@ -896,7 +898,8 @@ REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                   ops::ReluDoubleGradMaker);
 REGISTER_OPERATOR(
     relu_grad_grad,
-    ops::ActivationOpDoubleGrad2<ops::ReluGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<ops::ReluGradGradFunctor<float>::FwdDeps()>,
+    ops::ActivationDoubleGradOpInplaceInference);
 
 REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
 
@@ -921,7 +924,8 @@ REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
                   ops::LeakyReluDoubleGradMaker);
 REGISTER_OPERATOR(
     leaky_relu_grad_grad,
-    ops::ActivationOpDoubleGrad2<
-        ops::LeakyReluGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad2<
+        ops::LeakyReluGradGradFunctor<float>::FwdDeps()>,
+    ops::ActivationDoubleGradOpInplaceInference);
 
 REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
                                LeakyReluGradFunctor);
@@ -945,7 +949,9 @@ REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
                   ops::SqrtDoubleGradMaker);
 REGISTER_OPERATOR(
     sqrt_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
+    ops::ActivationDoubleGradOpInplaceInference);
+
 REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
 REGISTER_OP_CPU_KERNEL(
     sqrt_grad_grad,
     ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
                               ops::SqrtGradGradFunctor<float>>,
@@ -968,7 +974,8 @@ REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
                   ops::SquareDoubleGradMaker);
 REGISTER_OPERATOR(
     square_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);
+    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
+    ops::ActivationDoubleGradOpInplaceInference);
 
 REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
                                SquareGradFunctor);
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 316fb00eb9..ea19dcd3ab 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -1437,15 +1437,17 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
     auto* d = dev.eigen_device();
     auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
     auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
-      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
-    }
+    // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
+    // calculate dy first, so ddy can inplace ddx
     if (dOut) {
       auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
       auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
       dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
     }
+    if (ddOut) {
+      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
+    }
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
 };
@@ -1459,15 +1461,17 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
     auto* d = dev.eigen_device();
     auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
     auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
-    if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
-      ddout.device(*d) = ddx * static_cast<T>(2) * x;
-    }
+    // square GradGrad: ddy = 2 * x * ddx, dx = 2 * dout * ddx
+    // calculate dx first, so ddy can inplace ddx
     if (dX) {
       auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
       auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
       dx.device(*d) = ddx * static_cast<T>(2) * dout;
     }
+    if (ddOut) {
+      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      ddout.device(*d) = ddx * static_cast<T>(2) * x;
+    }
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
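The reordering inside SqrtGradGradFunctor and SquareGradGradFunctor is what makes the {"DDX", "DDOut"} inplace pair registered above safe: once DDOut is mapped onto DDX's buffer, writing ddout first would clobber the ddx values that the dout formula still needs. A minimal standalone sketch of that hazard for the sqrt case (toy C++, not Paddle code; all names are illustrative):

    // dout must be computed from ddx BEFORE the shared buffer is
    // overwritten with ddout = 0.5 * ddx / out.
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<double> out = {2.0, 4.0};   // out = sqrt(x)
      std::vector<double> dx = {1.0, 1.0};    // first-order grad of x
      std::vector<double> buf = {3.0, 5.0};   // holds ddx; ddout aliases it
      std::vector<double> dout(2), want(2);

      // reference dout = -dx * ddx / out, taken from the untouched ddx
      for (int i = 0; i < 2; ++i) want[i] = -dx[i] * buf[i] / out[i];

      // patched order: read ddx for dout first ...
      for (int i = 0; i < 2; ++i) dout[i] = -dx[i] * buf[i] / out[i];
      // ... then overwrite the shared buffer with ddout
      for (int i = 0; i < 2; ++i) buf[i] = 0.5 * buf[i] / out[i];

      assert(dout == want);  // swapping the two loops above breaks this
      return 0;
    }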
diff --git a/paddle/fluid/operators/elementwise/CMakeLists.txt b/paddle/fluid/operators/elementwise/CMakeLists.txt
index 37be1116ab..94886066ca 100644
--- a/paddle/fluid/operators/elementwise/CMakeLists.txt
+++ b/paddle/fluid/operators/elementwise/CMakeLists.txt
@@ -2,3 +2,5 @@ include(operators)
 register_operators()
 
 cc_test(test_elementwise_add_op_inplace SRCS test_elementwise_add_op_inplace.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
+cc_test(test_elementwise_div_grad_grad SRCS test_elementwise_div_grad_grad.cc DEPS op_registry elementwise_div_op scope device_context enforce executor)
+cc_test(test_elementwise_add_grad_grad SRCS test_elementwise_add_grad_grad.cc DEPS op_registry elementwise_add_op scope device_context enforce executor)
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
index bf12d8a1a6..fd93aa441e 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -54,7 +54,9 @@ REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpExplicitGrad,
                   ops::ElementwiseGradNoBufVarsInference,
                   ops::ElementwiseAddDoubleGradDescMaker);
 REGISTER_OPERATOR(elementwise_add_grad_grad,
-                  ops::ElementwiseOpDoubleGradWithoutDXDY);
+                  ops::ElementwiseOpDoubleGradWithoutDXDY,
+                  ops::ElementwiseDoubleGradOpInplace,
+                  ops::ElementwiseDoubleGradNoBufVarsInference);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_add,
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 8320272b4b..15b4bff0b7 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -36,4 +36,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
+                                        plat::float16>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
index 6689823d4a..f025a84520 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc
@@ -80,7 +80,8 @@ REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
                   ops::ElementwiseDivGradOpDescMaker);
 REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad,
                   ops::ElementwiseDivDoubleGradDescMaker);
-REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad);
+REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad,
+                  ops::ElementwiseDivDoubleGradOpInplace);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_div,
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index b38f84845b..4cd17b94e5 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -37,6 +37,8 @@ REGISTER_OP_CUDA_KERNEL(
     elementwise_div_grad_grad,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         float>,
+    ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
+                                        plat::float16>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                         double>,
     ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                        int>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -157,21 +157,22 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe);
     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
 
+    // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
+    // dY = Out * dX * ddY / Y - dX * ddX / Y
+    // dOut = - dX * ddY
+    // To save memory, (1) dout can be used as 'tmp' tensor, (2) ddout can
+    // inplace ddx
+    Tensor tmp;
     if (dOut) {
-      // dOut = - dX * ddY
-      default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
-      auto& place =
-          *ctx.template device_context<DeviceContext>().eigen_device();
-      auto dout = framework::EigenVector<T>::Flatten(*dOut);
-      dout.device(place) = static_cast<T>(-1) * dout;
+      tmp = *dOut;
+    } else {
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      tmp = ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
     }
-
     if (dY) {
       // dX_div_Y = dX / Y;
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      Tensor dX_div_Y =
-          ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
+      Tensor dX_div_Y = tmp;
       ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
           ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
 
@@ -179,14 +180,25 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
 
     if (ddOut) {
       // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
-      default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, ddOut);
+      default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
       ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &ddX_safe, ddOut, 0, SubFunctor<T>(), ddOut);
+          ctx, &ddX_safe, &tmp, 0, SubFunctor<T>(), &tmp);
       ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
-          ctx, ddOut, Y, axis, DivFunctor<T>(), ddOut);
+          ctx, &tmp, Y, axis, DivFunctor<T>(), ddOut);
+    }
+
+    if (dOut) {
+      // dOut = - dX * ddY
+      default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
+      auto& place =
+          *ctx.template device_context<DeviceContext>().eigen_device();
+      auto dout = framework::EigenVector<T>::Flatten(*dOut);
+      dout.device(place) = static_cast<T>(-1) * dout;
     }
   }
 };
 
+DECLARE_INPLACE_OP_INFERER(ElementwiseDivDoubleGradOpInplace, {"DDX", "DDOut"});
+
 }  // namespace operators
 }  // namespace paddle
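A scalar cross-check of the three formulas in the comment above (illustrative plain C++, not Paddle code; the input values are arbitrary):

    #include <cassert>
    #include <cmath>

    int main() {
      // out = x / y with x = 6
      double y = 2.0, out = 3.0;
      double dx = 0.5, ddx = 1.5, ddy = 0.25;

      double ddout = (ddx - out * ddy) / y;           // DDOut
      double dy = out * dx * ddy / y - dx * ddx / y;  // Y@GRAD
      double dout = -dx * ddy;                        // DOut

      assert(std::abs(ddout - (ddx / y - out * ddy / y)) < 1e-12);
      assert(std::abs(dy - dx / y * (out * ddy - ddx)) < 1e-12);
      assert(std::abs(dout + dx * ddy) < 1e-12);
      return 0;
    }

The factored form of dY, (dX / Y) * (Out * ddY - ddX), is also what the new elementwise_div test further below checks against.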
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
index 0f6af96ff3..69900e0637 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
@@ -77,7 +77,8 @@ REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
                   ops::ElementwiseMulOpGradDescMaker);
 REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad,
                   ops::ElementwiseMulDoubleGradDescMaker);
-REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad);
+REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
+                  ops::ElementwiseMulDoubleGradOpInplace);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul,
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index d18c7e66f1..d3c0dcb409 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -94,4 +94,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
+                                        plat::float16>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index 105707b803..aa8bfdf9d1 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -146,37 +146,48 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
 
     if (ddout) ddout->mutable_data<T>(ctx.GetPlace());
 
-    // dx = dout * ddy
-    // dy = dout * ddx
     Tensor ddx_safe, ddy_safe;
     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
 
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
-        ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
-        MulGradDY<T>());
-
+    // dx = dout * ddy
+    // dy = dout * ddx
     // ddout = ddx * y + x * ddy
+    // change computation sequence to save memory, so ddout can inplace ddx
+    // and dx can be used as 'tmp' tensor
+    // (1) dx = x * ddy
+    // (2) dy = dout * ddx
+    // (3) ddout = ddx * y
+    // (4) ddout = ddout + dx
+    // (5) dx = dout * ddy
     if (ddout) {
-      if (ddx && ddy) {
-        Tensor ddout_tmp;
-        ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
-
-        default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
-        default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, &ddout_tmp);
-
-        auto& place =
-            *ctx.template device_context<DeviceContext>().eigen_device();
-        auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
-        auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
-        ddout_t.device(place) = ddout_t + ddout_tmp_t;
-      } else {
-        if (ddx) default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
-        if (ddy) default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, ddout);
-      }
+      // use dx as scratch to save memory, rather than allocating a tmp tensor
+      Tensor* ddout_tmp = dx;
+
+      default_elementwise_mul<DeviceContext, T>(ctx, x, &ddy_safe, ddout_tmp);
+      int axis = ctx.Attr<int>("axis");
+      // NOTE: in the following ElemwiseGradCompute, the first output
+      // tensor is passed as nullptr, so the branch that computes it is
+      // never taken and MulGradDX is effectively unused; the dead branch
+      // has little effect on running speed.
+      ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
+          ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy,
+          MulGradDX<T>(), MulGradDY<T>());
+      default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, y, ddout);
+
+      auto& place =
+          *ctx.template device_context<DeviceContext>().eigen_device();
+      auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
+      auto ddout_tmp_t = framework::EigenVector<T>::Flatten(*ddout_tmp);
+      ddout_t.device(place) = ddout_t + ddout_tmp_t;
+      default_elementwise_mul<DeviceContext, T>(ctx, dout, &ddy_safe, dx);
     }
   }
 };
 
+DECLARE_INPLACE_OP_INFERER(ElementwiseMulDoubleGradOpInplace, {"DDX", "DDOut"},
+                           {"X", framework::GradVarName("X")},
+                           {"Y", framework::GradVarName("Y")});
+
 }  // namespace operators
 }  // namespace paddle
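A scalar trace of the reordered mul double-grad (toy C++, illustrative only): following steps (1) through (5) reproduces the direct formulas dx = dout * ddy, dy = dout * ddx, ddout = ddx * y + x * ddy, while ddx is last read before step (3) writes into its potentially shared buffer:

    #include <cassert>
    #include <cmath>

    int main() {
      double x = 2.0, y = 3.0, dout = 0.5, ddx = 1.5, ddy = 0.25;

      double dx = x * ddy;     // (1) dx doubles as the scratch tensor
      double dy = dout * ddx;  // (2) last read of ddx for dy
      double ddout = ddx * y;  // (3) may write over ddx's buffer
      ddout += dx;             // (4) fold in the scratch term
      dx = dout * ddy;         // (5) final dx overwrites the scratch

      assert(std::abs(dx - dout * ddy) < 1e-12);
      assert(std::abs(dy - dout * ddx) < 1e-12);
      assert(std::abs(ddout - (ddx * y + x * ddy)) < 1e-12);
      return 0;
    }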
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index aa6375d300..da678c5ee4 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -264,7 +264,18 @@ class ElementwiseOpDoubleGradWithoutDXDY
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    auto input_data_type = ctx.Input<Tensor>("DOut")->type();
+    framework::proto::VarType::Type input_data_type;
+    if (ctx.HasInput("DDX") == false) {
+      PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true,
+                        "Input(DDY) should not be null");
+      input_data_type = ctx.Input<Tensor>("DDY")->type();
+    } else if (ctx.HasInput("DDY") == false) {
+      PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true,
+                        "Input(DDX) should not be null");
+      input_data_type = ctx.Input<Tensor>("DDX")->type();
+    } else {
+      input_data_type = ctx.Input<Tensor>("DDX")->type();
+    }
 
 #ifdef PADDLE_WITH_MKLDNN
     if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -321,8 +332,11 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseOpInplace, {"X", "Out"});
 DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
                            {framework::GradVarName("Out"),
                             framework::GradVarName("X")});
+DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});
 
 DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
+                                      "Y", "DOut");
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
index b1ec10ea86..b3003092c7 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc
@@ -54,7 +54,9 @@ REGISTER_OPERATOR(elementwise_sub_grad, ops::ElementwiseOpExplicitGrad,
                   ops::ElementwiseGradNoBufVarsInference,
                   ops::ElementwiseSubDoubleGradDescMaker);
 REGISTER_OPERATOR(elementwise_sub_grad_grad,
-                  ops::ElementwiseOpDoubleGradWithoutDXDY);
+                  ops::ElementwiseOpDoubleGradWithoutDXDY,
+                  ops::ElementwiseDoubleGradOpInplace,
+                  ops::ElementwiseDoubleGradNoBufVarsInference);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_sub,
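The new GetExpectedKernelType body is needed because elementwise_add_grad_grad and elementwise_sub_grad_grad now mark DOut as a no-buffer input (ElementwiseDoubleGradNoBufVarsInference above), so its dtype can no longer be read safely, and DDX/DDY are each optional. A minimal sketch of the selection rule outside the framework (the enum and function names here are illustrative, not Paddle APIs):

    #include <optional>
    #include <stdexcept>

    enum class DType { kFP32, kFP64 };

    DType PickKernelDType(std::optional<DType> ddx, std::optional<DType> ddy) {
      if (!ddx) {
        if (!ddy) throw std::runtime_error("Input(DDY) should not be null");
        return *ddy;  // only DDY fed, the case the new add test exercises
      }
      return *ddx;    // DDX present: use it whether or not DDY exists
    }

    int main() {
      return PickKernelDType(std::nullopt, DType::kFP32) == DType::kFP32 ? 0 : 1;
    }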
diff --git a/paddle/fluid/operators/elementwise/test_elementwise_add_grad_grad.cc b/paddle/fluid/operators/elementwise/test_elementwise_add_grad_grad.cc
new file mode 100644
index 0000000000..532084f492
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/test_elementwise_add_grad_grad.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+
+USE_OP(elementwise_add);
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class TestElementwiseAddGradGradWithoutDDX
+    : public TestElementwiseOpGradGrad<T> {
+ public:
+  TestElementwiseAddGradGradWithoutDDX(const platform::Place &place,
+                                       const framework::DDim &dims)
+      : TestElementwiseOpGradGrad<T>("elementwise_add_grad_grad", place, dims,
+                                     {"Y", "DOut", "DDY"}, {"DDOut"}) {}
+
+  using TestElementwiseOpGradGrad<T>::feed_datas_;
+  using TestElementwiseOpGradGrad<T>::expected_outs_;
+  using TestElementwiseOpGradGrad<T>::dims_;
+  void ComputeExpectedOuts() override {
+    size_t numel = static_cast<size_t>(framework::product(dims_));
+    std::vector<T> ddout(numel);
+    for (size_t i = 0; i < numel; ++i) {
+      // ddOut = ddX + ddY = ddY if ddX is empty
+      ddout[i] = feed_datas_["DDY"][i];
+    }
+    expected_outs_["DDOut"] = ddout;
+  }
+
+  std::unique_ptr<framework::OperatorBase> CreateTestOp() override {
+    auto op = framework::OpRegistry::CreateOp(
+        this->op_type_, {{"Y", {"Y"}}, {"DOut", {"DOut"}}, {"DDY", {"DDY"}}},
+        {{"DDOut", {"DDOut"}}}, {{"use_mkldnn", false}, {"axis", 0}});
+    return op;
+  }
+};
+
+TEST(test_elementwise_add_grad_grad_without_ddx, cpu_place) {
+  framework::DDim dims({32, 64});
+  platform::CPUPlace p;
+  TestElementwiseAddGradGradWithoutDDX<float> test(p, dims);
+  ASSERT_TRUE(test.Check());
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(test_elementwise_add_grad_grad_without_ddx, gpu_place) {
+  framework::DDim dims({32, 64});
+  platform::CUDAPlace p(0);
+  TestElementwiseAddGradGradWithoutDDX<float> test(p, dims);
+  ASSERT_TRUE(test.Check());
+}
+#endif
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc b/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc
new file mode 100644
index 0000000000..e1f893dd2b
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/test_elementwise_div_grad_grad.cc
@@ -0,0 +1,97 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+
+USE_OP(elementwise_div);
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class TestElementwiseDivGradGradWithoutDout
+    : public TestElementwiseOpGradGrad<T> {
+ public:
+  TestElementwiseDivGradGradWithoutDout(const platform::Place &place,
+                                        const framework::DDim &dims)
+      : TestElementwiseOpGradGrad<T>("elementwise_div_grad_grad", place, dims,
+                                     {"Y", "Out", "DDX", "DDY", "DX"},
+                                     {"Y@GRAD", "DDOut"}) {}
+
+  using TestElementwiseOpGradGrad<T>::feed_datas_;
+  using TestElementwiseOpGradGrad<T>::expected_outs_;
+  using TestElementwiseOpGradGrad<T>::dims_;
+  void ComputeExpectedOuts() override {
+    size_t numel = static_cast<size_t>(framework::product(dims_));
+    std::vector<T> dy(numel);
+    std::vector<T> ddout(numel);
+    for (size_t i = 0; i < numel; ++i) {
+      // dY(Y@GRAD) = Out * dX * ddY / Y - dX * ddX / Y
+      dy[i] = (feed_datas_["DX"][i] / feed_datas_["Y"][i]) *
+              (feed_datas_["Out"][i] * feed_datas_["DDY"][i] -
+               feed_datas_["DDX"][i]);
+      // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
+      ddout[i] = (feed_datas_["DDX"][i] -
+                  feed_datas_["Out"][i] * feed_datas_["DDY"][i]) /
+                 (feed_datas_["Y"][i]);
+    }
+    expected_outs_["Y@GRAD"] = dy;
+    expected_outs_["DDOut"] = ddout;
+  }
+
+  std::unique_ptr<framework::OperatorBase> CreateTestOp() override {
+    auto op = framework::OpRegistry::CreateOp(
+        this->op_type_, {{"Y", {"Y"}},
+                         {"Out", {"Out"}},
+                         {"DDX", {"DDX"}},
+                         {"DDY", {"DDY"}},
+                         {"DX", {"DX"}}},
+        {{"Y@GRAD", {"Y@GRAD"}}, {"DDOut", {"DDOut"}}},
+        {{"use_mkldnn", false}, {"axis", 0}});
+    return op;
+  }
+};
+
+TEST(test_elementwise_div_grad_grad_without_dout, cpu_place) {
+  framework::DDim dims({32, 64});
+  platform::CPUPlace p;
+  TestElementwiseDivGradGradWithoutDout<float> test(p, dims);
+  ASSERT_TRUE(test.Check());
+}
+
+#ifdef PADDLE_WITH_CUDA
+TEST(test_elementwise_div_grad_grad_without_dout, gpu_place) {
+  framework::DDim dims({32, 64});
+  platform::CUDAPlace p(0);
+  TestElementwiseDivGradGradWithoutDout<float> test(p, dims);
+  ASSERT_TRUE(test.Check());
+}
+#endif
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h b/paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h
new file mode 100644
index 0000000000..c7ce5142c0
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/test_elementwise_op_grad_grad.h
@@ -0,0 +1,151 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/memory/memory.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+// currently, this test class only supports inputs and outputs that all
+// share the same dims
+template <typename T>
+class TestElementwiseOpGradGrad {
+ public:
+  TestElementwiseOpGradGrad(const std::string &op_type,
+                            const platform::Place &place,
+                            const framework::DDim &dims,
+                            const std::vector<std::string> &inputs,
+                            const std::vector<std::string> &outputs)
+      : op_type_(op_type),
+        place_(place),
+        dims_(dims),
+        inputs_(inputs),
+        outputs_(outputs) {}
+
+  void InitVarInScope(std::string var_name) {
+    in_out_tensors_[var_name] =
+        scope_.Var(var_name)->template GetMutable<framework::LoDTensor>();
+    in_out_tensors_[var_name]->Resize(dims_);
+    in_out_tensors_[var_name]->template mutable_data<T>(place_);
+  }
+
+  void InitFeedData(std::string var_name, size_t size) {
+    // generate random data
+    std::uniform_real_distribution<T> dist(static_cast<T>(10.0),
+                                           static_cast<T>(20.0));
+    std::mt19937 engine;
+    std::vector<T> data(size);
+    for (size_t i = 0; i < size; ++i) {
+      data[i] = dist(engine);
+    }
+    feed_datas_[var_name] = data;
+  }
+
+  void Setup() {
+    size_t numel = static_cast<size_t>(framework::product(dims_));
+    // init vars in scope and feed inputs
+    for (auto in_name : inputs_) {
+      InitVarInScope(in_name);
+      InitFeedData(in_name, numel);
+    }
+    for (auto out_name : outputs_) {
+      InitVarInScope(out_name);
+    }
+
+    // feeding: copy data into the input tensors; output tensors need no init
+    auto bytes = sizeof(T) * numel;
+    for (auto &in_name : inputs_) {
+      auto dst = in_out_tensors_[in_name]->template data<T>();
+      auto src = feed_datas_[in_name].data();
+      auto src_place = platform::CPUPlace();
+      if (platform::is_cpu_place(place_)) {
+        auto dst_place = boost::get<platform::CPUPlace>(place_);
+        memory::Copy(dst_place, dst, src_place, src, bytes);
+      } else if (platform::is_gpu_place(place_)) {
+#ifdef PADDLE_WITH_CUDA
+        auto dst_place = boost::get<platform::CUDAPlace>(place_);
+        memory::Copy(dst_place, dst, src_place, src, bytes, nullptr);
+#else
+        PADDLE_THROW("Not compiled with cuda");
+#endif
+      }
+    }
+
+    // calculate expected outputs
+    ComputeExpectedOuts();
+  }
+
+  bool Check() {
+    Setup();
+    auto op = CreateTestOp();
+    op->Run(scope_, place_);
+    platform::DeviceContextPool::Instance().Get(place_)->Wait();
+    framework::LoDTensor cpu_out;
+    PADDLE_ENFORCE_EQ(scope_.kids().empty(), true, "scope has child scopes");
+
+    // get outputs from scope and compare them with expected_outs
+    bool all_equal = true;
+    for (auto &out_name : outputs_) {
+      auto &out_tensor =
+          scope_.FindVar(out_name)->template Get<framework::LoDTensor>();
+      if (platform::is_gpu_place(place_)) {
+        framework::TensorCopySync(out_tensor, platform::CPUPlace(), &cpu_out);
+      } else {
+        cpu_out = out_tensor;
+      }
+      auto *out_ptr = cpu_out.data<T>();
+      size_t numel = static_cast<size_t>(framework::product(dims_));
+      auto is_equal = std::equal(out_ptr, out_ptr + numel,
+                                 expected_outs_[out_name].data());
+      if (!is_equal) {
+        all_equal = false;
+        break;
+      }
+    }
+    return all_equal;
+  }
+
+  virtual std::unique_ptr<framework::OperatorBase> CreateTestOp() = 0;
+  virtual void ComputeExpectedOuts() = 0;
+  virtual ~TestElementwiseOpGradGrad() {}
+
+ protected:
+  std::string op_type_;
+  platform::Place place_;
+  framework::DDim dims_;
+  std::vector<std::string> inputs_;
+  std::vector<std::string> outputs_;
+  std::map<std::string, paddle::framework::LoDTensor *> in_out_tensors_;
+  std::map<std::string, std::vector<T>> feed_datas_;
+  std::map<std::string, std::vector<T>> expected_outs_;
+  framework::Scope scope_;
+};
+
+}  // namespace operators
+}  // namespace paddle
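Wiring another op into this harness takes one subclass: choose the fed inputs, compute expected outputs on the CPU, and build the op with CreateOp. A hypothetical sketch for elementwise_sub_grad_grad, which this patch registers but does not test (every name below is illustrative, and the same includes as the tests above are assumed):

    template <typename T>
    class TestElementwiseSubGradGradWithoutDDX
        : public TestElementwiseOpGradGrad<T> {
     public:
      TestElementwiseSubGradGradWithoutDDX(const paddle::platform::Place &place,
                                           const paddle::framework::DDim &dims)
          : TestElementwiseOpGradGrad<T>("elementwise_sub_grad_grad", place,
                                         dims, {"Y", "DOut", "DDY"},
                                         {"DDOut"}) {}

      void ComputeExpectedOuts() override {
        size_t numel =
            static_cast<size_t>(paddle::framework::product(this->dims_));
        std::vector<T> ddout(numel);
        for (size_t i = 0; i < numel; ++i) {
          // ddOut = ddX - ddY = -ddY when ddX is empty
          ddout[i] = -this->feed_datas_["DDY"][i];
        }
        this->expected_outs_["DDOut"] = ddout;
      }

      std::unique_ptr<paddle::framework::OperatorBase> CreateTestOp() override {
        return paddle::framework::OpRegistry::CreateOp(
            this->op_type_,
            {{"Y", {"Y"}}, {"DOut", {"DOut"}}, {"DDY", {"DDY"}}},
            {{"DDOut", {"DDOut"}}}, {{"use_mkldnn", false}, {"axis", 0}});
      }
    };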