From 00615ebca2217c9890b1e1212eba1f5d753aa92b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 26 Jul 2017 17:50:13 +0800 Subject: [PATCH 1/4] Refine OpRegistry::AddInput/AddOutput Remove bool argument, use a class to handle that. --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/backward_test.cc | 50 +++++++++++++++++++++++ paddle/framework/op_registry.h | 61 +++++++++++++++------------- paddle/framework/op_registry_test.cc | 5 +-- paddle/framework/operator_test.cc | 4 +- paddle/operators/fc_op.cc | 4 +- 6 files changed, 89 insertions(+), 36 deletions(-) create mode 100644 paddle/framework/backward_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 26d93336b1..66f516a963 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -33,3 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_library(backward SRCS backward.cc DEPS net) +cc_test(backward_test SRCS backward_test.cc DEPS net) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc new file mode 100644 index 0000000000..b2286facfe --- /dev/null +++ b/paddle/framework/backward_test.cc @@ -0,0 +1,50 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" +namespace paddle { +namespace framework { + +class EmptyOp : public OperatorBase { + public: + void InferShape(const std::shared_ptr &scope) const override {} + void Run(const std::shared_ptr &scope, + const platform::DeviceContext &dev_ctx) const override {} +}; + +class RowwiseAddOp : public EmptyOp {}; +class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { + public: + RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input X of Add").IgnoreGradient(); + AddInput("b", "Bias of Add").IgnoreGradient(); + AddOutput("Out", "Out of Add").IgnoreGradient(); + AddComment("Add Op"); + } +}; + +class RowwiseAddGradOp : public EmptyOp {}; +} // namespace framework +} // namespace paddle + +namespace f = paddle::framework; +REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker); +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp); + +TEST(Backward, simple_grad) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + ASSERT_NE(fwd, nullptr); +} \ No newline at end of file diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 5bcd7ac927..e4ac8a6e76 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -86,43 +86,46 @@ class OpProtoAndCheckerMaker { } protected: - void AddInput(const std::string& name, const std::string& comment, - bool multiple = false, bool ignore_gradient = false) { + struct VariableBuilder { + VarProto* var_; + std::function on_multiple_; + std::function on_temporary_; + + VariableBuilder& SetMultiple() { + var_->set_multiple(true); + on_multiple_(); + return *this; + } + + VariableBuilder& SetTemporary() { + PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); + var_->set_temporary(true); + on_temporary_(); + return *this; + } + + VariableBuilder& IgnoreGradient() { + var_->set_ignore_gradient(true); + return *this; + } + }; + + VariableBuilder AddInput(const std::string& name, + const std::string& comment) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; - input->set_ignore_gradient(ignore_gradient); - input->set_multiple(multiple); - if (multiple) { - SetHasMultipleInput(); - } - } - - void AddInputs(const std::string& name, const std::string& comment, - bool ignore_gradient = false) { - AddInput(name, comment, true, ignore_gradient); + return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, + nullptr}; } - void AddOutput(const std::string& name, const std::string& comment, - bool temporary = false, bool multiple = false, - bool ignore_gradient = false) { + VariableBuilder AddOutput(const std::string& name, + const std::string& comment) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; - output->set_ignore_gradient(ignore_gradient); - output->set_multiple(multiple); - if (multiple) { - SetHasMultipleOutput(); - } - output->set_temporary(temporary); - if (temporary) { - SetHasTemporaryOutput(); - } - } - - void AddOutputs(const std::string& name, const std::string& comment, - bool temporary = false, bool ignore_gradient = false) { - AddOutput(name, comment, temporary, true, ignore_gradient); + return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, + [=] { this->SetHasTemporaryOutput(); }}; } template diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 2ef781bf86..a534f661af 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -36,9 +36,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("input", "input of cosine op"); - AddOutput("output", "output of cosine op", - /*temporary*/ true); + AddInput("input", "input of cosine op").SetMultiple(); + AddOutput("output", "output of cosine op").SetTemporary(); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 3fae356c3e..839280abbc 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -137,9 +137,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInputs("xs", "inputs of test op"); + AddInput("xs", "inputs of test op").SetMultiple(); AddInput("k", "input of test op"); - AddOutputs("ys", "outputs of test op"); + AddOutput("ys", "outputs of test op").SetMultiple(); AddAttr("scale", "scale of cosine op") .SetDefault(1.0) .LargerThan(0.0); diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc index c4a9f5937f..71ceda9587 100644 --- a/paddle/operators/fc_op.cc +++ b/paddle/operators/fc_op.cc @@ -50,8 +50,8 @@ public: AddInput("b", "the bias of fc operator"); AddOutput("Y", "the output of fc operator"); - AddOutput( - "before_act", "the before activation output of fc operator", true); + AddOutput("before_act", "the before activation output of fc operator") + .SetTemporary(); AddAttr("activation", "The activation key for fc layer") .SetDefault("sigmoid") .InEnum({"sigmoid", "softmax"}); From a2dc9614edfff8ab2a602e1ed605ffdc4155373a Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 26 Jul 2017 18:10:19 +0800 Subject: [PATCH 2/4] Add fill_zeros_like op --- paddle/operators/fill_zeros_like_op.cc | 58 ++++++++++++++++++++++++++ paddle/operators/fill_zeros_like_op.cu | 6 +++ paddle/operators/fill_zeros_like_op.h | 34 +++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 paddle/operators/fill_zeros_like_op.cc create mode 100644 paddle/operators/fill_zeros_like_op.cu create mode 100644 paddle/operators/fill_zeros_like_op.h diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc new file mode 100644 index 0000000000..3df3a2cfab --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fill_zeros_like_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { + +class FillZerosLike : public framework::OperatorWithKernel { +protected: + void InferShape( + const std::vector &inputs, + const std::vector &outputs) const override { + PADDLE_ENFORCE(inputs.size() == 1, + "Input size of FillZerosLike must be one."); + PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one."); + PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, + "Outputs of FillZerosLike must all be set."); + outputs[0]->Resize(inputs[0]->dims()); + } +}; + +class FillZerosLikeOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FillZerosLikeOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : framework::OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Src", "The input of fill-zeros-like op."); + AddOutput("Dst", "The varibale will be filled up with zeros."); + AddComment(R"DOC( +Fill up a vriable with zeros. + +The output will have the same size with input. +)DOC") + } +}; +} // namespace operators +} // namespace paddle + +REGISTER_OP(fill_zeros_like, + paddle::operators::FillZerosLikeOp, + paddle::operators::FillZerosLikeOpMaker); +EGISTER_OP_CPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernal); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.cu b/paddle/operators/fill_zeros_like_op.cu new file mode 100644 index 0000000000..55ad58f4f1 --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.cu @@ -0,0 +1,6 @@ +#include "paddle/framework/op_registry.h" +#include "paddle/operators/fill_zeros_like_op.h" + +REGISTER_OP_GPU_KERNEL( + fill_zeros_like, + paddle::operators::FillZerosLikeKernel); \ No newline at end of file diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h new file mode 100644 index 0000000000..ca44a201f7 --- /dev/null +++ b/paddle/operators/fill_zeros_like_op.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "glog/logging.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +template +class FillZerosLikeKernel : public framework::OpKernel { +public: + void Compute(const framework::KernelContext& context) const override { + auto* output = context.Output(0)->GetMutable(); + output->mutable_data(context.GetPlace()); + framework::EigenVector::Flatten(*output).setZero(); + } +}; + +} // namespace operators +} // namespace paddle From e32e306821fc8ffd79ccbe6f9c090d1ad217fd56 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 26 Jul 2017 19:37:10 +0800 Subject: [PATCH 3/4] Develop backward building precess of single op --- paddle/framework/backward.cc | 23 +++++++++++++++++++++-- paddle/framework/operator.h | 3 +++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index d8653b5dd6..1531cb53f9 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -12,8 +12,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -#include -#include +#include "paddle/framework/backward.h" +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { @@ -71,6 +72,24 @@ static std::shared_ptr BackwardImpl( //! TODO(dzh) } else { //! TODO(fjy) + std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); + for (std::string& grad_input : grad_op->inputs_) { + if (no_grad_names.count(grad_input)) { + std::string prefix = grad_input.substr( + 0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size()); + grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX(); + std::vector fill_zeros_in = {prefix}; + std::vector fill_zeros_out = {grad_input}; + net.AddOp(OpRegistry::CreateOp("fill_zeros_like", fill_zeros_in, + fill_zeros_out, AttributeMap())); + } + } + for (std::string& grad_output : grad_op->output_) { + if (no_grad_names.count(grad_output)) { + grad_output = OperatorBase::EMPTY_VAR_NAME(); + } + } + net.AddOp(grad_op); } net->CompleteAddOp(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 65fddb6811..c2cd21a080 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -67,6 +67,9 @@ class OperatorBase { /// e.g. Variable "x@GRAD" is the gradient of varibale "x". static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; } + /// Variables with this suffix are supposed to be filled up with zeros. + static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; } + virtual ~OperatorBase() {} template From 831d4e1c85dedc2bca8cc997ccc612208dc05c38 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 26 Jul 2017 19:37:40 +0800 Subject: [PATCH 4/4] Refining Unittest --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/backward_test.cc | 142 ++++++++++++++++++++++- paddle/framework/grad_op_builder.cc | 19 ++- paddle/framework/grad_op_builder.h | 4 +- paddle/framework/grad_op_builder_test.cc | 2 +- paddle/framework/op_registry.h | 7 +- 6 files changed, 152 insertions(+), 24 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 66f516a963..7febaaa527 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -33,4 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_library(backward SRCS backward.cc DEPS net) -cc_test(backward_test SRCS backward_test.cc DEPS net) +cc_test(backward_test SRCS backward_test.cc DEPS backward) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index b2286facfe..cc00279db5 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -12,8 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/framework/backward.h" #include +#include "paddle/framework/net.h" #include "paddle/framework/op_registry.h" + namespace paddle { namespace framework { @@ -24,10 +27,9 @@ class EmptyOp : public OperatorBase { const platform::DeviceContext &dev_ctx) const override {} }; -class RowwiseAddOp : public EmptyOp {}; -class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { +class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: - RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input X of Add").IgnoreGradient(); AddInput("b", "Bias of Add").IgnoreGradient(); @@ -36,15 +38,143 @@ class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { } }; -class RowwiseAddGradOp : public EmptyOp {}; +class MulOpMaker : public OpProtoAndCheckerMaker { + public: + MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("A", "A"); + AddInput("B", "B"); + AddOutput("Out", "Out"); + AddComment("Mul"); + } +}; + +class SigmoidOpMaker : public OpProtoAndCheckerMaker { + public: + SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "X"); + AddOutput("Y", "Y"); + AddComment("Sigmoid"); + } +}; + +class FcOp : public NetOp { + public: + void Init() override { + AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")}, + {Output("before_act")}, {})); + auto b_name = Input("b"); + if (b_name != EMPTY_VAR_NAME()) { + AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name}, + {Output("before_act")}, {})); + } + AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")}, + {Output("Out")}, {})); + CompleteAddOp(false); + } +}; + +class FcOpMaker : public OpProtoAndCheckerMaker { + public: + FcOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "x"); + AddInput("W", "w"); + AddInput("b", "b"); + AddOutput("before_act", "before act").SetTemporary(); + AddOutput("Out", ""); + AddComment(""); + } +}; + +class ManyOutputOpMaker : public OpProtoAndCheckerMaker { + public: + ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("x", "x"); + AddOutput("y", "y"); + AddOutput("z", "z"); + AddComment(""); + } +}; + +class FillZeroOpMaker : public OpProtoAndCheckerMaker { + public: + FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("x", "x"); + AddOutput("out", "out"); + AddComment(""); + } +}; } // namespace framework } // namespace paddle namespace f = paddle::framework; -REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker); -REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp); +using EnforceNotMet = paddle::platform::EnforceNotMet; +REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); +REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); +REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); +REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); +REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); +REGISTER_OP(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); +REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); +REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); TEST(Backward, simple_grad) { auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); ASSERT_NE(fwd, nullptr); + auto gop = f::OpRegistry::CreateGradOp(*fwd); + ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); + ASSERT_EQ("rowwise_add_grad", gop->type_); + ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); + ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]); + + // LOG(INFO) << gop->Output("X" + "@GRAD"); +} + +TEST(Backward, not_for_network) { + auto fwd = + f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"}, + {{"temporary_index", std::vector{1}}}); + ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); +} + +TEST(Backward, all_input_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"X", "b"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, all_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); + auto backward = f::Backward(*fwd, {"Out"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_TRUE(net->ops_.empty()); +} + +TEST(Backward, part_of_output_are_not_need) { + auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {}); + auto backward = f::Backward(*fwd, {"Z"}); + ASSERT_TRUE(backward->IsNetOp()); + auto net = static_cast(backward.get()); + ASSERT_EQ(net->ops_.size(), 2); + + auto &fill_zero = *net->ops_[0]; + ASSERT_EQ("fill_zeros_like", fill_zero.type_); + ASSERT_EQ(1, fill_zero.inputs_.size()); + ASSERT_EQ("Z", fill_zero.inputs_[0]); + ASSERT_EQ(1, fill_zero.outputs_.size()); + ASSERT_EQ("Z@ZERO", fill_zero.outputs_[0]); + + auto &d_many_out = *net->ops_[1]; + ASSERT_EQ("many_output_op_grad", d_many_out.type_); + ASSERT_EQ(1 + 2 + 2, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("Z@ZERO", d_many_out.Input("z@GRAD")); } \ No newline at end of file diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 6235be75f2..dd686cc782 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -20,7 +20,7 @@ namespace framework { OperatorBase* GradOpBuilder::Build() { BuildOpInOutArgList(); - std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_); + std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); grad_op->type_ = grad_op_type; CompleteGradOp(grad_op); @@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, } void GradOpBuilder::BuildOpInOutArgList() { - const OpProto& op_proto = OpRegistry::protos().at(op_->type_); - const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); + const OpProto& op_proto = OpRegistry::protos().at(op_.type_); + const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_)); const std::vector& in_format = - op_->attrs_.count("input_format") - ? op_->GetAttr>("input_format") + op_.attrs_.count("input_format") + ? op_.GetAttr>("input_format") : std::vector(); const std::vector& out_format = - op_->attrs_.count("output_format") - ? op_->GetAttr>("output_format") + op_.attrs_.count("output_format") + ? op_.GetAttr>("output_format") : std::vector(); for (const auto& var : op_proto.inputs()) { arg_list_.emplace_back( @@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } (*varmap)[var_name] = idx++; size_t pre_sz = in_out.size(); - auto base_it = - arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin(); + auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin(); std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::back_inserter(in_out)); if (is_grad) { @@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, } void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { - grad_op->attrs_ = op_->attrs_; + grad_op->attrs_ = op_.attrs_; grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("output_format"); VarIndexMap* grad_varmap = new VarIndexMap(); diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h index 2ecf39479b..cc7a76f372 100644 --- a/paddle/framework/grad_op_builder.h +++ b/paddle/framework/grad_op_builder.h @@ -29,7 +29,7 @@ class GradOpBuilder { using VarIndexMap = std::unordered_map; public: - GradOpBuilder(const OperatorBase* op) : op_(op) {} + GradOpBuilder(const OperatorBase& op) : op_(op) {} OperatorBase* Build(); private: @@ -40,7 +40,7 @@ class GradOpBuilder { std::vector& format, VarIndexMap* varmap, int& idx, bool is_grad) const; void CompleteGradOp(OperatorBase* grad_op) const; - const OperatorBase* op_; + const OperatorBase& op_; std::vector> arg_list_; }; diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index 288a7841cd..e9cf3b9798 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -11,7 +11,7 @@ namespace framework { TEST(GradOpBuilder, AddTwo) { std::shared_ptr add_op( OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); - std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(add_op); + std::shared_ptr grad_add_op = OpRegistry::CreateGradOp(*add_op); EXPECT_EQ(static_cast(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast(grad_add_op->outputs_.size()), 2); EXPECT_EQ(grad_add_op->Input("X"), "x"); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index e4ac8a6e76..cee20b1112 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -303,11 +303,10 @@ class OpRegistry { return CreateOp(op_desc.type(), inputs, outputs, attrs); } - static std::shared_ptr CreateGradOp( - std::shared_ptr op) { - PADDLE_ENFORCE(!op->IsNetOp(), + static std::shared_ptr CreateGradOp(const OperatorBase& op) { + PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); - GradOpBuilder builder(op.get()); + GradOpBuilder builder(op); std::shared_ptr grad_op(builder.Build()); grad_op->Init(); return grad_op;