From c23af80afe4d577f8e7b0c1c4bdd2dd53d5377f1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 30 Sep 2017 16:11:40 -0700 Subject: [PATCH 1/9] Change macro --- paddle/framework/grad_op_desc_maker.h | 4 ++ paddle/framework/op_registry.h | 82 +++++++++++++++------------ 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h index cb4d160bd0..672cd7dbaf 100644 --- a/paddle/framework/grad_op_desc_maker.h +++ b/paddle/framework/grad_op_desc_maker.h @@ -79,6 +79,7 @@ class GradOpDescMakerBase { class SingleGradOpDescMaker : public GradOpDescMakerBase { public: + using GradOpDescMakerBase::GradOpDescMakerBase; std::vector operator()() const { return {this->Apply()}; } protected: @@ -86,6 +87,9 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase { }; class DefaultGradOpDescMaker : public SingleGradOpDescMaker { + public: + using SingleGradOpDescMaker::SingleGradOpDescMaker; + protected: virtual OpDescBind Apply() const { OpDescBind grad; diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 4ee2c7d275..7db095369e 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -24,14 +24,27 @@ limitations under the License. */ #include "paddle/framework/details/op_registry.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/grad_op_builder.h" +#include "paddle/framework/grad_op_desc_maker.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" namespace paddle { namespace framework { +class Registrar { + public: + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which, + // however, are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. So, as long as the callee code calls USE_OP, the global + // registrar variable won't be removed by the linker. + void Touch() {} +}; template -struct OperatorRegistrar { +struct OperatorRegistrar : public Registrar { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); @@ -70,19 +83,6 @@ class OpRegistry { static std::unique_ptr CreateGradOp(const OperatorBase& op); }; -class Registrar { - public: - // In our design, various kinds of classes, e.g., operators and kernels, - // have their corresponding registry and registrar. The action of - // registration is in the constructor of a global registrar variable, which, - // however, are not used in the code that calls package framework, and would - // be removed from the generated binary file by the linker. To avoid such - // removal, we add Touch to all registrar classes and make USE_OP macros to - // call this method. So, as long as the callee code calls USE_OP, the global - // registrar variable won't be removed by the linker. - void Touch() {} -}; - template class OpRegistrar : public Registrar { public: @@ -138,33 +138,43 @@ class OpKernelRegistrar : public Registrar { __test_global_namespace_##uniq_name##__>::value, \ msg) +#define VA_ARGS(...) , ##__VA_ARGS__ + +#define REGISTER_OPERATOR(op_type, op_class, ...) 
\ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op__##op_type, \ + "REGISTER_OPERATOR must be called in global namespace"); \ + class _OpClass_##op_type##_ : public op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ + }; \ + static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \ + __VA_ARGS__)> \ + __op_registrar_##op_type##__(#op_type); \ + int TouchOpRegistrar_##op_type() { \ + __op_registrar_##op_type##__.Touch(); \ + return 0; \ + } + /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ - grad_op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ - class _OpClass_##op_type##_ : public op_class { \ - public: \ - DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ - DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ - }; \ - class _OpGradClass_##op_type##_ : public grad_op_class { \ - public: \ - DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ - DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ - }; \ - static ::paddle::framework::OpRegistrar< \ - _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ - __op_registrar_##op_type##__(#op_type, #grad_op_type); \ - int TouchOpRegistrar_##op_type() { \ - __op_registrar_##op_type##__.Touch(); \ - return 0; \ - } +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ + REGISTER_OPERATOR(grad_op_type, grad_op_class); \ + class _GradOpDescMaker_##grad_op_type##_ \ + : public ::paddle::framework::DefaultGradOpDescMaker { \ + using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \ + \ + protected: \ + virtual std::string GradOpType() const { return #grad_op_type; } \ + }; \ + REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ + op_maker_class) #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ - REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) + REGISTER_OPERATOR(op_type, op_class, op_maker_class) /** * Macro to register OperatorKernel. 
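Patch 1/9 replaces the fixed five-argument REGISTER_OP with a variadic REGISTER_OPERATOR built on two tricks: a Touch() method that gives USE_OP-style macros a symbol to reference, so the linker cannot dead-strip the global registrar, and `, ##__VA_ARGS__` to swallow the leading comma when no extra maker classes are passed (this patch routes it through a VA_ARGS helper macro; patch 7 inlines it). A minimal, self-contained sketch of the pattern, assuming nothing beyond standard C++ plus the GNU comma-swallowing extension — the Demo* names are illustrative stand-ins, not the real framework types:

    #include <cstdio>

    // Global registrars do their registration work in their constructors.
    // Touch() exists only so calling code can reference the global object,
    // which keeps the linker from discarding it as unused.
    struct Registrar {
      void Touch() {}
    };

    // Mirrors OperatorRegistrar<OpClass, ProtoMaker, GradOpDescMaker, ...>:
    // the parameter pack carries the optional maker classes.
    template <typename OpClass, typename... Extras>
    struct DemoOpRegistrar : Registrar {
      explicit DemoOpRegistrar(const char* op_type) {
        std::printf("registering %s (+%zu maker classes)\n", op_type,
                    sizeof...(Extras));
      }
    };

    // `, ##__VA_ARGS__` is a GNU extension (also honored by MSVC): the comma
    // is dropped when __VA_ARGS__ is empty, so no-extra-class and
    // many-extra-class invocations both yield a valid template argument list.
    #define DEMO_REGISTER_OPERATOR(op_type, op_class, ...)        \
      static DemoOpRegistrar<op_class, ##__VA_ARGS__>             \
          __demo_registrar_##op_type##__(#op_type);               \
      int TouchDemoRegistrar_##op_type() {                        \
        __demo_registrar_##op_type##__.Touch();                   \
        return 0;                                                 \
      }

    struct MyOp {};
    struct MyOpMaker {};

    DEMO_REGISTER_OPERATOR(my_op, MyOp, MyOpMaker)

    int main() { return TouchDemoRegistrar_my_op(); }

Under these assumptions, DEMO_REGISTER_OPERATOR(my_op, MyOp) and DEMO_REGISTER_OPERATOR(my_op, MyOp, MyOpMaker) both expand to a well-formed registrar, which is exactly what lets REGISTER_OP_WITHOUT_GRADIENT above forward to REGISTER_OPERATOR with fewer arguments.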
From d64bedf638d66cc4fedb63bcfd389a1058359798 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 30 Sep 2017 16:44:16 -0700 Subject: [PATCH 2/9] Stash --- paddle/framework/backward.cc | 3 + paddle/framework/backward_test.cc | 31 ++-- paddle/framework/grad_op_builder.cc | 97 ----------- paddle/framework/grad_op_builder.h | 28 ---- paddle/framework/grad_op_builder_test.cc | 201 ----------------------- paddle/framework/op_desc.h | 12 ++ paddle/framework/op_registry.cc | 5 - paddle/framework/op_registry.h | 2 - 8 files changed, 25 insertions(+), 354 deletions(-) delete mode 100644 paddle/framework/grad_op_builder.cc delete mode 100644 paddle/framework/grad_op_builder.h delete mode 100644 paddle/framework/grad_op_builder_test.cc diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 0ec18de5b8..ab2567a25c 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -154,6 +154,9 @@ static std::unique_ptr BackwardRecursive( net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { + OpDescBind fwd_desc; + fwd_desc.SetInput(forwardOp.Inputs()); + std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 6932f5b989..28fc6f9ced 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -159,16 +159,16 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, f::NOP); -TEST(Backward, simple_op_grad) { - auto fwd = f::OpRegistry::CreateOp( - "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); - ASSERT_NE(fwd, nullptr); - auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1UL, gop->Inputs().size()); - ASSERT_EQ("rowwise_add_grad", gop->Type()); - ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); - ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); -} +// TEST(Backward, simple_op_grad) { +// auto fwd = f::OpRegistry::CreateOp( +// "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); +// ASSERT_NE(fwd, nullptr); +// auto gop = f::OpRegistry::CreateGradOp(*fwd); +// ASSERT_EQ(1UL, gop->Inputs().size()); +// ASSERT_EQ("rowwise_add_grad", gop->Type()); +// ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); +// ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); +//} TEST(Backward, simple_op_not_need_grad) { auto fwd = f::OpRegistry::CreateOp( @@ -286,17 +286,6 @@ TEST(Backward, net_shared_weight) { ASSERT_EQ("add", bwd_net->ops_[2]->Type()); } -TEST(Backward, op_register_grad_not_for_network) { - auto fwd = - f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}}, - {{"mul_result", {"mul_out"}}, - {"add_result", {"add_out"}}, - {"Out", {"out1"}}}, - {{"temporary_index", std::vector{0, 1}}}); - - ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet); -} - TEST(Backward, op_all_input_are_not_need) { auto fwd = f::OpRegistry::CreateOp( "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc deleted file mode 100644 index 3661ce41be..0000000000 --- a/paddle/framework/grad_op_builder.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOpArgType::OUT WARRANTIES OR CONDITIONS OF ANY KOpArgType::IND, either -express or implied. See the License for the specific language governing -permissions and limitations under the License. */ - -#include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace framework { -enum class OpArgType { IN, OUT }; - -static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, - bool is_grad, VariableNameMap* vars) { - const auto& src_inout = - src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); - auto& dst_inout = *vars; - auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); - const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { - if (arg.not_in_gradient() && !is_grad) continue; - const std::string src_name = arg.name(); - std::string dst_name = is_grad ? GradVarName(src_name) : src_name; - dst_inout[dst_name].reserve(src_inout.at(src_name).size()); - for (auto& var_name : src_inout.at(src_name)) { - std::string s = is_grad ? GradVarName(var_name) : var_name; - dst_inout[dst_name].emplace_back(s); - } - } -} - -OperatorBase* BuildGradOp(const OperatorBase* op) { - auto& info = OpInfoMap::Instance().Get(op->Type()); - PADDLE_ENFORCE(info.HasGradientOp()); - - VariableNameMap inputs; - VariableNameMap outputs; - TransOpArg(op, OpArgType::IN, false, &inputs); // I - TransOpArg(op, OpArgType::OUT, false, &inputs); // O - TransOpArg(op, OpArgType::OUT, true, &inputs); // OG - TransOpArg(op, OpArgType::IN, true, &outputs); // IG - - auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_); - return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); -} - -static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type, - bool is_grad, OpDescBind* dst_op, - const OpArgType& dst_type) { - PADDLE_ENFORCE(dst_op != nullptr, - "Protobuf desc of gradient op must be initialized first."); - const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto(); - const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); - for (const auto& arg : src_arg_list) { - if (arg.not_in_gradient() && !is_grad) continue; - const std::string src_name = arg.name(); - std::vector vars = src_type == OpArgType::IN - ? src_op->Input(src_name) - : src_op->Output(src_name); - if (is_grad) { - for (std::string& var : vars) { - var = GradVarName(var); - } - } - std::string dst_name = is_grad ? GradVarName(src_name) : src_name; - dst_type == OpArgType::IN ? 
dst_op->SetInput(dst_name, vars) - : dst_op->SetOutput(dst_name, vars); - } -} - -void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) { - auto& info = OpInfoMap::Instance().Get(forw_op->Type()); - PADDLE_ENFORCE(info.HasGradientOp()); - - grad_op->SetType(info.grad_op_type_); - - TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN); - TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT); - - grad_op->SetAttrMap(forw_op->GetAttrMap()); -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/grad_op_builder.h b/paddle/framework/grad_op_builder.h deleted file mode 100644 index b601406061..0000000000 --- a/paddle/framework/grad_op_builder.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/framework/op_desc.h" -#include "paddle/framework/operator.h" - -namespace paddle { -namespace framework { - -OperatorBase* BuildGradOp(const OperatorBase* op); - -void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op); - -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc deleted file mode 100644 index d09892f81b..0000000000 --- a/paddle/framework/grad_op_builder_test.cc +++ /dev/null @@ -1,201 +0,0 @@ -#include "paddle/framework/grad_op_builder.h" -#include -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" - -USE_OP(add); - -namespace paddle { -namespace framework { - -class MutiInOutOpMaker : public OpProtoAndCheckerMaker { - public: - MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable(); - AddInput("In3", "another single input"); - AddOutput("Out1", "a single output"); - AddOutput("Out2_mult", "a multiple output").AsDuplicable(); - AddComment("test op with multiple inputs and outputs"); - } -}; - -class IOIgnoredOpMaker : public OpProtoAndCheckerMaker { - public: - IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("In1", "a single input"); - AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient(); - AddInput("In3_mult", "another multiple input").AsDuplicable(); - AddOutput("Out1_mult", "a multiple output").AsDuplicable(); - AddOutput("Out2", "a single output").NotInGradient(); - AddComment("op with inputs and outputs ignored in gradient calculating"); - } -}; - -} // namespace framework -} // namespace paddle - -namespace f = paddle::framework; - -TEST(GradOpBuilder, AddTwo) { - std::shared_ptr add_op(f::OpRegistry::CreateOp( - "add", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); - 
std::shared_ptr grad_add_op = - f::OpRegistry::CreateGradOp(*add_op); - EXPECT_EQ(grad_add_op->Inputs().size(), 4UL); - EXPECT_EQ(grad_add_op->Outputs().size(), 2UL); - EXPECT_EQ(grad_add_op->Input("X"), "x"); - EXPECT_EQ(grad_add_op->Input("Y"), "y"); - EXPECT_EQ(grad_add_op->Input("Out"), "out"); - EXPECT_EQ(grad_add_op->Input(f::GradVarName("Out")), f::GradVarName("out")); - EXPECT_EQ(grad_add_op->Output(f::GradVarName("X")), f::GradVarName("x")); - EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); -} - -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); - -TEST(GradOpBuilder, MutiInOut) { - std::shared_ptr test_op(f::OpRegistry::CreateOp( - "mult_io", {{"In1", {"in1"}}, - {"In2_mult", {"in2_1", "in2_2", "in2_3"}}, - {"In3", {"in3"}}}, - {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {})); - std::shared_ptr grad_test_op = - f::OpRegistry::CreateGradOp(*test_op); - - ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL); - EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In2_mult"), - std::vector({"in2_1", "in2_2", "in2_3"})); - EXPECT_EQ(grad_test_op->Input("In3"), "in3"); - EXPECT_EQ(grad_test_op->Input("Out1"), "out1"); - EXPECT_EQ(grad_test_op->Inputs("Out2_mult"), - std::vector({"out2_1", "out2_2"})); - EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")), - f::GradVarName("out1")); - EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")), - std::vector( - {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - - ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), - std::vector({f::GradVarName("in2_1"), - f::GradVarName("in2_2"), - f::GradVarName("in2_3")})); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3")); -} - -TEST(GradOpBuilder, IOIgnoredInGradient) { - std::shared_ptr test_op(f::OpRegistry::CreateOp( - "io_ignored", {{"In1", {"in1"}}, - {"In2_mult", {"in2_1", "in2_2"}}, - {"In3_mult", {"in3_1", "in3_2"}}}, - {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {})); - std::shared_ptr grad_test_op = - f::OpRegistry::CreateGradOp(*test_op); - - // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL); - EXPECT_EQ(grad_test_op->Input("In1"), "in1"); - EXPECT_EQ(grad_test_op->Inputs("In3_mult"), - std::vector({"in3_1", "in3_2"})); - EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), - std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), - std::vector( - {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); - EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), - f::GradVarName("out2")); - - ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); - EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), - std::vector( - {f::GradVarName("in2_1"), f::GradVarName("in2_2")})); - EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")), - std::vector( - {f::GradVarName("in3_1"), f::GradVarName("in3_2")})); -} - -TEST(GradOpDescBuilder, MutiInOut) { - f::OpDescBind *forw_op = new f::OpDescBind(); - forw_op->SetType("mult_io"); - forw_op->SetInput("In1", {"in1"}); - forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"}); - forw_op->SetInput("In3", 
{"in3"}); - forw_op->SetOutput("Out1", {"out1"}); - forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"}); - - f::OpDescBind *grad_op = new f::OpDescBind(); - f::CompleteGradOpDesc(forw_op, grad_op); - - EXPECT_EQ(grad_op->Type(), "mult_io_grad"); - ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL); - EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); - EXPECT_EQ(grad_op->Input("In2_mult"), - std::vector({"in2_1", "in2_2", "in2_3"})); - EXPECT_EQ(grad_op->Input("In3"), std::vector({"in3"})); - EXPECT_EQ(grad_op->Input("Out1"), std::vector({"out1"})); - EXPECT_EQ(grad_op->Input("Out2_mult"), - std::vector({"out2_1", "out2_2"})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")), - std::vector({f::GradVarName("out1")})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")), - std::vector( - {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - - ASSERT_EQ(grad_op->OutputNames().size(), 3UL); - EXPECT_EQ(grad_op->Output(f::GradVarName("In1")), - std::vector({f::GradVarName("in1")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")), - std::vector({f::GradVarName("in2_1"), - f::GradVarName("in2_2"), - f::GradVarName("in2_3")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In3")), - std::vector({f::GradVarName("in3")})); - delete forw_op; - delete grad_op; -} - -TEST(GradOpDescBuilder, IOIgnoredInGradient) { - f::OpDescBind *forw_op = new f::OpDescBind(); - forw_op->SetType("io_ignored"); - forw_op->SetInput("In1", {"in1"}); - forw_op->SetInput("In2_mult", {"in2_1", "in2_2"}); - forw_op->SetInput("In3_mult", {"in3_1", "in3_2"}); - forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"}); - forw_op->SetOutput("Out2", {"out2"}); - - f::OpDescBind *grad_op = new f::OpDescBind(); - f::CompleteGradOpDesc(forw_op, grad_op); - - EXPECT_EQ(grad_op->Type(), "io_ignored_grad"); - // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL); - EXPECT_EQ(grad_op->Input("In1"), std::vector({"in1"})); - EXPECT_EQ(grad_op->Input("In3_mult"), - std::vector({"in3_1", "in3_2"})); - EXPECT_EQ(grad_op->Input("Out1_mult"), - std::vector({"out1_1", "out1_2"})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")), - std::vector( - {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); - EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")), - std::vector({f::GradVarName("out2")})); - - ASSERT_EQ(grad_op->OutputNames().size(), 3UL); - EXPECT_EQ(grad_op->Output(f::GradVarName("In1")), - std::vector({f::GradVarName("in1")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")), - std::vector( - {f::GradVarName("in2_1"), f::GradVarName("in2_2")})); - EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")), - std::vector( - {f::GradVarName("in3_1"), f::GradVarName("in3_2")})); - delete forw_op; - delete grad_op; -} \ No newline at end of file diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 851a305061..4b001fb964 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -74,6 +74,18 @@ class OpDescBind { return MapKeys(outputs_); } + void SetInput( + const std::unordered_map> &input) { + this->inputs_ = input; + this->need_update_ = true; + } + + void SetOutput( + const std::unordered_map> &output) { + this->outputs_ = output; + this->need_update_ = true; + } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index b0e85dd49f..0a2b6fd582 100644 --- a/paddle/framework/op_registry.cc +++ 
b/paddle/framework/op_registry.cc @@ -52,10 +52,5 @@ std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { return CreateOp(op_desc.type(), inputs, outputs, attrs); } -std::unique_ptr OpRegistry::CreateGradOp(const OperatorBase& op) { - PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops"); - return std::unique_ptr(BuildGradOp(&op)); -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 7db095369e..0f377f34cb 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -79,8 +79,6 @@ class OpRegistry { AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); - - static std::unique_ptr CreateGradOp(const OperatorBase& op); }; template From 578a357b616ee188d692764843ae834a449e81c2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 2 Oct 2017 15:12:20 -0700 Subject: [PATCH 3/9] Make compile pass --- paddle/framework/CMakeLists.txt | 4 +--- paddle/framework/backward.cc | 33 +++++++++++++++++++++++++++++---- paddle/framework/op_desc.h | 14 ++++++++------ paddle/framework/op_registry.cc | 6 ++++++ paddle/framework/op_registry.h | 8 +++++--- 5 files changed, 49 insertions(+), 16 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 9140854a96..eb316b4c8c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) -cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) -cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) +cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op) py_proto_compile(framework_py_proto SRCS framework.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index ab2567a25c..eb34bc3693 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -13,6 +13,7 @@ limitations under the License. 
*/ #include "paddle/framework/backward.h" +#include "paddle/operators/net_op.h" #include #include @@ -24,6 +25,32 @@ namespace paddle { namespace framework { +static inline std::unique_ptr CreateGradOp( + const OperatorBase& op) { + OpDescBind op_desc; + op_desc.SetInputMap(op.Inputs()); + op_desc.SetOutputMap(op.Outputs()); + op_desc.SetType(op.Type()); + op_desc.SetAttrMap(op.Attrs()); + auto& info = OpInfoMap::Instance().Get(op.Type()); + auto grad_descs = info.grad_op_maker_(op_desc); + std::vector> grad_ops; + grad_ops.reserve(grad_descs.size()); + std::transform( + grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops), + [](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); }); + PADDLE_ENFORCE_GT(grad_ops.size(), 0); + if (grad_ops.size() == 1) { + return std::move(grad_ops[0]); + } else { + auto net_op = new operators::NetOp(); + for (auto& grad_op : grad_ops) { + net_op->AppendOp(std::move(grad_op)); + } + return std::unique_ptr(net_op); + } +} + template static void ForEachVarName(const Map& names, T callback) { for (auto& name : names) { @@ -154,10 +181,8 @@ static std::unique_ptr BackwardRecursive( net->InsertOp(pos.first + 1, std::move(pos.second)); } } else { - OpDescBind fwd_desc; - fwd_desc.SetInput(forwardOp.Inputs()); - - std::unique_ptr grad_op(OpRegistry::CreateGradOp(forwardOp)); + std::unique_ptr grad_op(CreateGradOp(forwardOp)); + PADDLE_ENFORCE(grad_op != nullptr); ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op]( const std::string& grad_input) { diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index ec92d08768..72d7a0379b 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -76,18 +76,22 @@ class OpDescBind { return MapKeys(outputs_); } - void SetInput( - const std::unordered_map> &input) { + void SetInputMap(const VariableNameMap &input) { this->inputs_ = input; this->need_update_ = true; } - void SetOutput( - const std::unordered_map> &output) { + void SetOutputMap(const VariableNameMap &output) { this->outputs_ = output; this->need_update_ = true; } + void Sync(); + + const VariableNameMap &Inputs() const { return inputs_; } + + const VariableNameMap &Outputs() const { return outputs_; } + private: template static std::vector MapKeys(const MapType &map) { @@ -99,8 +103,6 @@ class OpDescBind { return ret_val; } - void Sync(); - OpDesc op_desc_; VariableNameMap inputs_; VariableNameMap outputs_; diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 0a2b6fd582..35f280981b 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -52,5 +52,11 @@ std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { return CreateOp(op_desc.type(), inputs, outputs, attrs); } +std::unique_ptr OpRegistry::CreateOp(OpDescBind* op_desc) { + op_desc->Sync(); + return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(), + op_desc->GetAttrMap()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 0f377f34cb..d14f70008b 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -23,8 +23,8 @@ limitations under the License. 
*/ #include "paddle/framework/attribute.h" #include "paddle/framework/details/op_registry.h" #include "paddle/framework/framework.pb.h" -#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_desc_maker.h" +#include "paddle/framework/op_desc.h" #include "paddle/framework/operator.h" #include "paddle/framework/scope.h" @@ -46,15 +46,15 @@ class Registrar { template struct OperatorRegistrar : public Registrar { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { + std::cerr << "Reg operator " << op_type << std::endl; PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); static_assert(sizeof...(ARGS) != 0, "OperatorRegistrar should be invoked at least by OpClass"); details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info); + OpInfoMap::Instance().Insert(op_type, info); } - ~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); } - const char* op_type; OpInfo info; @@ -79,6 +79,8 @@ class OpRegistry { AttributeMap attrs); static std::unique_ptr CreateOp(const OpDesc& op_desc); + + static std::unique_ptr CreateOp(OpDescBind* op_desc); }; template From ff8766e910a4d9ba1e208458de2719708d6663d3 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 2 Oct 2017 15:50:00 -0700 Subject: [PATCH 4/9] Stash --- paddle/framework/backward_test.cc | 2 ++ paddle/framework/details/op_registry.h | 1 + paddle/framework/op_desc.h | 2 +- paddle/framework/op_info.h | 8 -------- paddle/framework/op_proto_maker.h | 5 ----- paddle/framework/op_registry.cc | 4 +++- 6 files changed, 7 insertions(+), 15 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 28fc6f9ced..85f1dd91ed 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -378,6 +378,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ + 2U /* internal variable number*/); + std::cerr << grad_fc.DebugString() << std::endl; + EXPECT_EQ(grad_fc.Outputs(all).size(), 2UL /* input number of mul*/ + 2UL /* input number of rowwise_add diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index daa474e8c5..c805dae7d7 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -85,6 +85,7 @@ struct OpInfoFiller { info->proto_ = new OpProto; info->checker_ = new OpAttrChecker(); auto maker = T(info->proto_, info->checker_); + std::cerr << "Assign Maker " << op_type << std::endl; maker.Validate(); info->proto_->set_type(op_type); PADDLE_ENFORCE( diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h index 72d7a0379b..4c1ada05f0 100644 --- a/paddle/framework/op_desc.h +++ b/paddle/framework/op_desc.h @@ -98,7 +98,7 @@ class OpDescBind { std::vector ret_val; ret_val.reserve(map.size()); std::transform( - map.begin(), map.end(), ret_val.begin(), + map.begin(), map.end(), std::back_inserter(ret_val), [](const typename MapType::value_type &pair) { return pair.first; }); return ret_val; } diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 683476dfd4..ab13dad962 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -42,19 +42,11 @@ struct OpInfo { return *proto_; } - const OpAttrChecker& Checker() const { - PADDLE_ENFORCE_NOT_NULL(checker_, - "Operator Checker has not been registered"); - return *checker_; - } - const OpCreator& Creator() const { 
PADDLE_ENFORCE_NOT_NULL(creator_, "Operator Creator has not been registered"); return creator_; } - - bool HasGradientOp() const { return !grad_op_type_.empty(); } }; class OpInfoMap { diff --git a/paddle/framework/op_proto_maker.h b/paddle/framework/op_proto_maker.h index 4d55a37db9..a134befd90 100644 --- a/paddle/framework/op_proto_maker.h +++ b/paddle/framework/op_proto_maker.h @@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker { var_->set_intermediate(true); return *this; } - - VariableBuilder& NotInGradient() { - var_->set_not_in_gradient(true); - return *this; - } }; VariableBuilder AddInput(const std::string& name, const std::string& comment); diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 35f280981b..ac6aa8d28e 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -23,7 +23,9 @@ std::unique_ptr OpRegistry::CreateOp( const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, AttributeMap attrs) { auto& info = OpInfoMap::Instance().Get(type); - info.Checker().Check(attrs); + if (info.checker_ != nullptr) { + info.checker_->Check(attrs); + } auto op = info.Creator()(type, inputs, outputs, attrs); return std::unique_ptr(op); } From 46c551b2997537a70ea82fd55067fd57cc4c59d5 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 2 Oct 2017 18:27:21 -0700 Subject: [PATCH 5/9] Complete Register Gradient in compile time --- paddle/framework/backward_test.cc | 32 ++++++++----- paddle/framework/details/op_registry.h | 1 - paddle/framework/framework.proto | 1 - paddle/framework/op_info.h | 3 ++ paddle/framework/op_registry.h | 1 - paddle/operators/mean_op.cc | 21 ++++++++- paddle/operators/minus_op.cc | 46 +++++++++---------- paddle/operators/pad_op.cc | 22 +++++++-- paddle/operators/scale_op.cc | 33 ++++++------- .../softmax_with_cross_entropy_op.cc | 45 ++++++++++++------ paddle/operators/sum_op.cc | 41 +++++++++-------- 11 files changed, 152 insertions(+), 94 deletions(-) diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 85f1dd91ed..93688c383b 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -21,24 +21,34 @@ namespace paddle { namespace framework { -using OperatorBase = framework::OperatorBase; -using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker; -using OpProto = framework::OpProto; -using OpAttrChecker = framework::OpAttrChecker; -using Scope = framework::Scope; using DeviceContext = platform::DeviceContext; class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add").NotInGradient(); - AddInput("b", "Bias of Add").NotInGradient(); - AddOutput("Out", "Out of Add").NotInGradient(); + AddInput("X", "Input X of Add"); + AddInput("b", "Bias of Add"); + AddOutput("Out", "Out of Add"); AddComment("Add Op"); } }; +class RowWiseAddGradMaker : public SingleGradOpDescMaker { + public: + using SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + OpDescBind Apply() const override { + OpDescBind grad_op; + grad_op.SetInput(GradVarName("Out"), OutputGrad("Out")); + grad_op.SetOutput(GradVarName("X"), InputGrad("X")); + grad_op.SetOutput(GradVarName("b"), InputGrad("b")); + grad_op.SetType("rowwise_add_grad"); + return grad_op; + } +}; + class MulOpMaker : public OpProtoAndCheckerMaker { public: MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -148,8 +158,9 @@ 
class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, - f::NOP); +REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker, + f::RowWiseAddGradMaker); +REGISTER_OPERATOR(rowwise_add_grad, f::NOP); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); @@ -378,7 +389,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ + 2U /* internal variable number*/); - std::cerr << grad_fc.DebugString() << std::endl; EXPECT_EQ(grad_fc.Outputs(all).size(), 2UL /* input number of mul*/ diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h index c805dae7d7..daa474e8c5 100644 --- a/paddle/framework/details/op_registry.h +++ b/paddle/framework/details/op_registry.h @@ -85,7 +85,6 @@ struct OpInfoFiller { info->proto_ = new OpProto; info->checker_ = new OpAttrChecker(); auto maker = T(info->proto_, info->checker_); - std::cerr << "Assign Maker " << op_type << std::endl; maker.Validate(); info->proto_->set_type(op_type); PADDLE_ENFORCE( diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index 951c7afbc1..e90a816afa 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -66,7 +66,6 @@ message OpProto { optional bool duplicable = 3 [ default = false ]; optional bool intermediate = 4 [ default = false ]; - optional bool not_in_gradient = 5 [ default = false ]; } // AttrProto describes the C++ type Attribute. 
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 2d9568c320..8c2a9178a7 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -17,11 +17,14 @@ #include #include #include + #include "paddle/framework/attribute.h" #include "paddle/framework/op_desc.h" #include "paddle/framework/type_defs.h" #include "paddle/platform/macros.h" +#include "glog/logging.h" + namespace paddle { namespace framework { diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index d14f70008b..da112fa488 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -46,7 +46,6 @@ class Registrar { template struct OperatorRegistrar : public Registrar { explicit OperatorRegistrar(const char* op_type) : op_type(op_type) { - std::cerr << "Reg operator " << op_type << std::endl; PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), "'%s' is registered more than once.", op_type); static_assert(sizeof...(ARGS) != 0, diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index d799239d4e..0c84cbb5a7 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of mean op"); - AddOutput("Out", "The output of mean op").NotInGradient(); + AddOutput("Out", "The output of mean op"); AddComment(R"DOC( Mean Operator )DOC"); } @@ -52,11 +52,28 @@ class MeanGradOp : public framework::OperatorWithKernel { } }; +class MeanGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + framework::OpDescBind Apply() const override { + framework::OpDescBind grad_op; + grad_op.SetType("mean_grad"); + grad_op.SetInput("X", Input("X")); + grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X")); + return grad_op; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); + +REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker); +REGISTER_OPERATOR(mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); REGISTER_OP_CPU_KERNEL(mean_grad, diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index ce049d4d7b..1b3ae9a9a6 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker { public: MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The left tensor of minus operator.").NotInGradient(); - AddInput("Y", "The right tensor of minus operator.").NotInGradient(); - AddOutput("Out", "The output tensor of minus operator.").NotInGradient(); + AddInput("X", "The left tensor of minus operator."); + AddInput("Y", "The right tensor of minus operator."); + AddOutput("Out", "The output tensor of minus operator."); AddComment(R"DOC(Minus Operator @@ -64,26 +64,25 @@ or not. But the output only shares the LoD with input `X`. 
)DOC"); } }; -template -class MinusGradOp : public NetOp { + +class MinusGradMaker : public framework::GradOpDescMakerBase { public: - MinusGradOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - auto out_grad = Input(framework::GradVarName("Out")); - auto x_grad = Output(framework::GradVarName("X")); - auto y_grad = Output(framework::GradVarName("Y")); - - // x_grad = out_grad - AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, - {{"Y", {x_grad}}}, {})); - - framework::AttributeMap scale_attr; - scale_attr["scale"] = static_cast(-1); - AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}}, - {{"Out", {y_grad}}}, scale_attr)); - CompleteAddOp(false); + using framework::GradOpDescMakerBase::GradOpDescMakerBase; + + std::vector operator()() const override { + std::vector ops; + ops.resize(2); + + ops[0].SetType("scale"); + ops[0].SetInput("X", OutputGrad("Out")); + ops[0].SetOutput("Out", InputGrad("X")); + ops[0].SetAttr("scale", 1.0f); + + ops[1].SetType("scale"); + ops[1].SetInput("X", OutputGrad("Out")); + ops[1].SetOutput("Out", InputGrad("Y")); + ops[1].SetAttr("scale", -1.0f); + return ops; } }; @@ -91,7 +90,6 @@ class MinusGradOp : public NetOp { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, - ops::MinusGradOp); +REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker); REGISTER_OP_CPU_KERNEL(minus, ops::MinusKernel); diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 04ebb14f6e..4bd25fa46a 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker { "The input should be a k-D tensor(k > 0 and k < 7)"); AddOutput("Out", "The output of pad op." - "A tensor with the same shape as X.") - .NotInGradient(); + "A tensor with the same shape as X."); AddComment(R"DOC( Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). 
As an example: @@ -111,11 +110,28 @@ class PadOpGrad : public framework::OperatorWithKernel { } }; +class PadOpGradMaker : public framework::SingleGradOpDescMaker { + protected: + framework::OpDescBind Apply() const override { + framework::OpDescBind bind; + bind.SetInput("X", Input("X")); + bind.SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind.SetOutput(framework::GradVarName("X"), InputGrad("X")); + bind.SetAttrMap(Attrs()); + return bind; + } + + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad); + +REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker); +REGISTER_OPERATOR(pad_grad, ops::PadOpGrad); REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel); REGISTER_OP_CPU_KERNEL(pad_grad, ops::PadGradKernel); diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index e92501e128..40f0960923 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { public: ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input tensor of scale operator.").NotInGradient(); - AddOutput("Out", "The output tensor of scale operator.").NotInGradient(); + AddInput("X", "The input tensor of scale operator."); + AddOutput("Out", "The output tensor of scale operator."); AddComment(R"DOC(Scale operator The equation is: Out = scale*X @@ -52,21 +52,18 @@ The equation is: Out = scale*X } }; -// The operator to calculate gradients of a scale operator is just the scale -// operator itself. -// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out)) -template -class ScaleGradOp : public NetOp { +class ScaleGradMaker : public framework::SingleGradOpDescMaker { public: - ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : NetOp(type, inputs, outputs, attrs) { - AppendOp(framework::OpRegistry::CreateOp( - "scale", {{"X", {Input(framework::GradVarName("Out"))}}}, - {{"Out", {Output(framework::GradVarName("X"))}}}, - {{"scale", Attr("scale")}})); - CompleteAddOp(false); + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + framework::OpDescBind Apply() const override { + framework::OpDescBind grad_op; + grad_op.SetType("scale"); + grad_op.SetInput("X", OutputGrad("Out")); + grad_op.SetOutput("Out", InputGrad("X")); + grad_op.SetAttr("scale", GetAttr("scale")); + return grad_op; } }; @@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp { namespace ops = paddle::operators; -REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad, - ops::ScaleGradOp); +REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, + ops::ScaleGradMaker); REGISTER_OP_CPU_KERNEL(scale, ops::ScaleKernel); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index a76489871f..87dcc3f240 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -27,15 +27,14 @@ class SoftmaxWithCrossEntropyOpMaker AddInput("Logits", "(Tensor, default: Tensor), The unscaled log probabilities " "which is a 2-D tensor with shape [N x K]. 
N is the batch_size, " - "and K is the class number.") - .NotInGradient(); - AddInput( - "Label", - "(Tensor, default: Tensor), The ground truth which is a 2-D " - "tensor. " - "If softLable is set to 0, Label is a Tensor with shape [N x 1]. " - "If softLable is set to 1, Label is a Tensor " - "with shape [N x K]."); + "and K is the class number."); + AddInput("Label", + "(Tensor, default: Tensor), The ground truth which is a 2-D " + "tensor. " + "If softLable is set to 0, Label is a Tensor with shape [N x " + "1]. " + "If softLable is set to 1, Label is a Tensor " + "with shape [N x K]."); AddOutput( "Softmax", "(Tensor, default: Tensor), A 2-D tensor with shape [N x K]. " @@ -163,15 +162,35 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { } }; +class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + framework::OpDescBind Apply() const override { + framework::OpDescBind grad_op; + grad_op.SetType("softmax_with_cross_entropy_grad"); + grad_op.SetInput("Label", Input("Label")); + grad_op.SetInput("Softmax", Output("Softmax")); + grad_op.SetInput("Loss", Output("Loss")); + grad_op.SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax")); + grad_op.SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + grad_op.SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); + grad_op.SetAttrMap(Attrs()); + return grad_op; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, - ops::SoftmaxWithCrossEntropyOpMaker, - softmax_with_cross_entropy_grad, - ops::SoftmaxWithCrossEntropyOpGrad); +REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, + ops::SoftmaxWithCrossEntropyOpMaker, + ops::SoftmaxWithCrossEntropyOpMaker); +REGISTER_OPERATOR(softmax_with_cross_entropy_grad, + ops::SoftmaxWithCrossEntropyOpGrad); REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyKernel); REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 7c422b4770..5ae13492b3 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { public: SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "the input tensors of sum operator.") - .AsDuplicable() - .NotInGradient(); - AddOutput("Out", "the output tensor of sum operator.").NotInGradient(); + AddInput("X", "the input tensors of sum operator.").AsDuplicable(); + AddOutput("Out", "the output tensor of sum operator."); AddComment(R"DOC( Sum the input tensors. @@ -58,23 +56,25 @@ or not. But the output only shares the LoD with the first input. 
} }; -class SumGradOp : public NetOp { +class SumGradMaker : public framework::GradOpDescMakerBase { public: - SumGradOp(const std::string& type, const framework::VariableNameMap& inputs, - const framework::VariableNameMap& outputs, - const framework::AttributeMap& attrs) - : NetOp(type, inputs, outputs, attrs) { - auto& x_grad_names = Outputs(framework::GradVarName("X")); - auto out_grad_name = this->Input(framework::GradVarName("Out")); + using framework::GradOpDescMakerBase::GradOpDescMakerBase; - framework::AttributeMap grad_attrs; - grad_attrs["scale"] = 1.0f; - for (auto& x_grad_name : x_grad_names) { - AppendOp(framework::OpRegistry::CreateOp( - "scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}}, - grad_attrs)); - } - CompleteAddOp(false); + std::vector operator()() const override { + auto x_grads = InputGrad("X"); + std::vector grad_ops; + grad_ops.reserve(x_grads.size()); + auto og = OutputGrad("Out"); + std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), + [&og](const std::string& x_grad) { + framework::OpDescBind grad_op; + grad_op.SetType("scale"); + grad_op.SetInput("X", og); + grad_op.SetOutput("Out", {x_grad}); + grad_op.SetAttr("scale", 1.0f); + return grad_op; + }); + return grad_ops; } }; @@ -82,5 +82,6 @@ class SumGradOp : public NetOp { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp); + +REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker); REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel); From e119177a8c2372335573d27f273c286580db2bd8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 4 Oct 2017 12:11:34 -0700 Subject: [PATCH 6/9] Use unique_ptr --- paddle/framework/backward.cc | 8 +++--- paddle/framework/backward_test.cc | 20 +++++++-------- paddle/operators/mean_op.cc | 14 +++++------ paddle/operators/minus_op.cc | 25 +++++++++++-------- paddle/operators/pad_op.cc | 20 +++++++-------- paddle/operators/scale_op.cc | 14 +++++------ .../softmax_with_cross_entropy_op.cc | 22 ++++++++-------- paddle/operators/sum_op.cc | 17 +++++++------ 8 files changed, 72 insertions(+), 68 deletions(-) diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 22c8c83f13..40390d4150 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -36,9 +36,11 @@ static inline std::unique_ptr CreateGradOp( auto grad_descs = info.grad_op_maker_(op_desc); std::vector> grad_ops; grad_ops.reserve(grad_descs.size()); - std::transform( - grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops), - [](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); }); + std::transform(grad_descs.begin(), grad_descs.end(), + std::back_inserter(grad_ops), + [](const std::unique_ptr& grad_desc) { + return OpRegistry::CreateOp(grad_desc.get()); + }); PADDLE_ENFORCE_GT(grad_ops.size(), 0); if (grad_ops.size() == 1) { return std::move(grad_ops[0]); diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index c88e85f8c4..830d0427fa 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -39,13 +39,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker { using SingleGradOpDescMaker::SingleGradOpDescMaker; protected: - OpDescBind Apply() const override { - OpDescBind grad_op; - grad_op.SetInput(GradVarName("Out"), OutputGrad("Out")); - grad_op.SetOutput(GradVarName("X"), InputGrad("X")); - grad_op.SetOutput(GradVarName("b"), InputGrad("b")); - 
grad_op.SetType("rowwise_add_grad"); - return grad_op; + std::unique_ptr Apply() const override { + auto grad_op = new OpDescBind(); + grad_op->SetInput(GradVarName("Out"), OutputGrad("Out")); + grad_op->SetOutput(GradVarName("X"), InputGrad("X")); + grad_op->SetOutput(GradVarName("b"), InputGrad("b")); + grad_op->SetType("rowwise_add_grad"); + return std::unique_ptr(grad_op); } }; @@ -147,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { public: SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "the input tensors of sum operator.") - .AsDuplicable() - .NotInGradient(); - AddOutput("Out", "the output tensor of sum operator.").NotInGradient(); + AddInput("X", "the input tensors of sum operator.").AsDuplicable(); + AddOutput("Out", "the output tensor of sum operator."); AddComment(""); } }; diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 0c84cbb5a7..339c089e87 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -57,13 +57,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker { using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; protected: - framework::OpDescBind Apply() const override { - framework::OpDescBind grad_op; - grad_op.SetType("mean_grad"); - grad_op.SetInput("X", Input("X")); - grad_op.SetInput(framework::GradVarName("Out"), OutputGrad("Out")); - grad_op.SetOutput(framework::GradVarName("X"), InputGrad("X")); - return grad_op; + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDescBind(); + grad_op->SetType("mean_grad"); + grad_op->SetInput("X", Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + return std::unique_ptr(grad_op); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 1b3ae9a9a6..aced8636b9 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -69,19 +69,22 @@ class MinusGradMaker : public framework::GradOpDescMakerBase { public: using framework::GradOpDescMakerBase::GradOpDescMakerBase; - std::vector operator()() const override { - std::vector ops; + std::vector> operator()() + const override { + std::vector> ops; ops.resize(2); - ops[0].SetType("scale"); - ops[0].SetInput("X", OutputGrad("Out")); - ops[0].SetOutput("Out", InputGrad("X")); - ops[0].SetAttr("scale", 1.0f); - - ops[1].SetType("scale"); - ops[1].SetInput("X", OutputGrad("Out")); - ops[1].SetOutput("Out", InputGrad("Y")); - ops[1].SetAttr("scale", -1.0f); + ops[0].reset(new framework::OpDescBind()); + ops[0]->SetType("scale"); + ops[0]->SetInput("X", OutputGrad("Out")); + ops[0]->SetOutput("Out", InputGrad("X")); + ops[0]->SetAttr("scale", 1.0f); + + ops[1].reset(new framework::OpDescBind()); + ops[1]->SetType("scale"); + ops[1]->SetInput("X", OutputGrad("Out")); + ops[1]->SetOutput("Out", InputGrad("Y")); + ops[1]->SetAttr("scale", -1.0f); return ops; } }; diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 4bd25fa46a..9445917739 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -111,18 +111,18 @@ class PadOpGrad : public framework::OperatorWithKernel { }; class PadOpGradMaker : public framework::SingleGradOpDescMaker { - protected: - framework::OpDescBind Apply() const override { - framework::OpDescBind bind; - bind.SetInput("X", Input("X")); - 
bind.SetInput(framework::GradVarName("Out"), OutputGrad("Out")); - bind.SetOutput(framework::GradVarName("X"), InputGrad("X")); - bind.SetAttrMap(Attrs()); - return bind; - } - public: using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto* bind = new framework::OpDescBind(); + bind->SetInput("X", Input("X")); + bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + bind->SetOutput(framework::GradVarName("X"), InputGrad("X")); + bind->SetAttrMap(Attrs()); + return std::unique_ptr(bind); + } }; } // namespace operators diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index 40f0960923..e225aecc27 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -57,13 +57,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; protected: - framework::OpDescBind Apply() const override { - framework::OpDescBind grad_op; - grad_op.SetType("scale"); - grad_op.SetInput("X", OutputGrad("Out")); - grad_op.SetOutput("Out", InputGrad("X")); - grad_op.SetAttr("scale", GetAttr("scale")); - return grad_op; + std::unique_ptr Apply() const override { + auto *grad_op = new framework::OpDescBind(); + grad_op->SetType("scale"); + grad_op->SetInput("X", OutputGrad("Out")); + grad_op->SetOutput("Out", InputGrad("X")); + grad_op->SetAttr("scale", GetAttr("scale")); + return std::unique_ptr(grad_op); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 87dcc3f240..bc9868874d 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -167,17 +167,17 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; protected: - framework::OpDescBind Apply() const override { - framework::OpDescBind grad_op; - grad_op.SetType("softmax_with_cross_entropy_grad"); - grad_op.SetInput("Label", Input("Label")); - grad_op.SetInput("Softmax", Output("Softmax")); - grad_op.SetInput("Loss", Output("Loss")); - grad_op.SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax")); - grad_op.SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); - grad_op.SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); - grad_op.SetAttrMap(Attrs()); - return grad_op; + std::unique_ptr Apply() const override { + auto* grad_op = new framework::OpDescBind(); + grad_op->SetType("softmax_with_cross_entropy_grad"); + grad_op->SetInput("Label", Input("Label")); + grad_op->SetInput("Softmax", Output("Softmax")); + grad_op->SetInput("Loss", Output("Loss")); + grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax")); + grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss")); + grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits")); + grad_op->SetAttrMap(Attrs()); + return std::unique_ptr(grad_op); } }; diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index 5ae13492b3..c701ee8dde 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -60,19 +60,20 @@ class SumGradMaker : public framework::GradOpDescMakerBase { public: using framework::GradOpDescMakerBase::GradOpDescMakerBase; - std::vector operator()() const override { + std::vector> operator()() + const override { auto x_grads = InputGrad("X"); - std::vector grad_ops; + std::vector> 
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
index 5ae13492b3..c701ee8dde 100644
--- a/paddle/operators/sum_op.cc
+++ b/paddle/operators/sum_op.cc
@@ -60,19 +60,20 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
  public:
   using framework::GradOpDescMakerBase::GradOpDescMakerBase;
 
-  std::vector<framework::OpDescBind> operator()() const override {
+  std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
+      const override {
     auto x_grads = InputGrad("X");
-    std::vector<framework::OpDescBind> grad_ops;
+    std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
     grad_ops.reserve(x_grads.size());
     auto og = OutputGrad("Out");
     std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                    [&og](const std::string& x_grad) {
-                     framework::OpDescBind grad_op;
-                     grad_op.SetType("scale");
-                     grad_op.SetInput("X", og);
-                     grad_op.SetOutput("Out", {x_grad});
-                     grad_op.SetAttr("scale", 1.0f);
-                     return grad_op;
+                     auto* grad_op = new framework::OpDescBind();
+                     grad_op->SetType("scale");
+                     grad_op->SetInput("X", og);
+                     grad_op->SetOutput("Out", {x_grad});
+                     grad_op->SetAttr("scale", 1.0f);
+                     return std::unique_ptr<framework::OpDescBind>(grad_op);
                    });
     return grad_ops;
   }

From c4effc7d2d1225f15b1ec51ab54753d3ef693de7 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 4 Oct 2017 15:34:28 -0700
Subject: [PATCH 7/9] Fix CI Test

---
 paddle/framework/backward.cc          |  3 +-
 paddle/framework/op_info.h            |  7 +++-
 paddle/framework/op_registry.h        | 34 +++++++++----------
 paddle/operators/minus_op.cc          | 33 +++++++++++-------
 paddle/operators/pad_op.cc            |  1 +
 .../softmax_with_cross_entropy_op.cc  |  9 +++--
 6 files changed, 52 insertions(+), 35 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 40390d4150..9193a1593e 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -33,7 +33,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.grad_op_maker_(op_desc);
+  auto grad_descs = info.GradOpMaker()(op_desc);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -49,6 +49,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
     for (auto& grad_op : grad_ops) {
       net_op->AppendOp(std::move(grad_op));
     }
+    net_op->CompleteAddOp();
     return std::unique_ptr<OperatorBase>(net_op);
   }
 }
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 6f87e055b4..968f587b46 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -30,7 +30,6 @@ namespace framework {
 
 struct OpInfo {
   OpCreator creator_;
-  std::string grad_op_type_;
   GradOpMakerFN grad_op_maker_;
   OpProto* proto_{nullptr};
   OpAttrChecker* checker_{nullptr};
@@ -51,6 +50,12 @@ struct OpInfo {
                    "Operator Creator has not been registered");
     return creator_;
   }
+
+  const GradOpMakerFN& GradOpMaker() const {
+    PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
+                            "Operator GradOpMaker has not been registered.");
+    return grad_op_maker_;
+  }
 };
 
 class OpInfoMap {
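[Editor's note] With grad_op_type_ gone from OpInfo, the guarded GradOpMaker() accessor is the only way to reach the maker, and it fails fast at the call site. A sketch of the lookup path as backward.cc now exercises it; fwd_op_type and fwd_op_desc are illustrative placeholders, not names from the patch:

    // Not a standalone program; assumes the framework headers above.
    auto& info = framework::OpInfoMap::Instance().Get(fwd_op_type);
    // PADDLE_ENFORCE_NOT_NULL inside GradOpMaker() reports a missing
    // registration immediately, instead of invoking an empty std::function
    // and crashing later.
    auto grad_descs = info.GradOpMaker()(fwd_op_desc);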
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index da112fa488..a4f0144ce8 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar {
                             __test_global_namespace_##uniq_name##__>::value, \
       msg)
 
-#define VA_ARGS(...) , ##__VA_ARGS__
-
-#define REGISTER_OPERATOR(op_type, op_class, ...)                             \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
-      __reg_op__##op_type,                                                    \
-      "REGISTER_OPERATOR must be called in global namespace");                \
-  class _OpClass_##op_type##_ : public op_class {                             \
-   public:                                                                    \
-    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_);                            \
-    DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class);                   \
-  };                                                                          \
-  static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \
-      __VA_ARGS__)>                                                           \
-      __op_registrar_##op_type##__(#op_type);                                 \
-  int TouchOpRegistrar_##op_type() {                                          \
-    __op_registrar_##op_type##__.Touch();                                     \
-    return 0;                                                                 \
+#define REGISTER_OPERATOR(op_type, op_class, ...)                             \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
+      __reg_op__##op_type,                                                    \
+      "REGISTER_OPERATOR must be called in global namespace");                \
+  class _OpClass_##op_type##_ : public op_class {                             \
+   public:                                                                    \
+    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_);                            \
+    DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class);                   \
+  };                                                                          \
+  static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_,        \
+                                                ##__VA_ARGS__>                \
+      __op_registrar_##op_type##__(#op_type);                                 \
+  int TouchOpRegistrar_##op_type() {                                          \
+    __op_registrar_##op_type##__.Touch();                                     \
+    return 0;                                                                 \
   }
 
 /**
@@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar {
     virtual std::string GradOpType() const { return #grad_op_type; }          \
   };                                                                          \
   REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_,    \
-                    op_maker_class)
+                    op_maker_class);
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
   REGISTER_OPERATOR(op_type, op_class, op_maker_class)
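[Editor's note] Dropping the VA_ARGS helper works because `, ##__VA_ARGS__` is the GCC/Clang paste extension that swallows the preceding comma when the variadic list is empty, so `OperatorRegistrar<_OpClass_##op_type##_>` stays well-formed with no trailing comma. A self-contained demo of just that mechanism; LOG_AT is a made-up macro for illustration:

    #include <cstdio>

    // With zero variadic arguments, "##" elides the leading comma, so both
    // invocations below expand to valid printf calls.
    #define LOG_AT(fmt, ...) std::printf("[demo] " fmt "\n", ##__VA_ARGS__)

    int main() {
      LOG_AT("no arguments");    // expands without a trailing comma
      LOG_AT("value = %d", 42);  // expands with the argument list
      return 0;
    }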
 REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                   ops::SoftmaxWithCrossEntropyOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc
index aced8636b9..7057dcbd6e 100644
--- a/paddle/operators/minus_op.cc
+++ b/paddle/operators/minus_op.cc
@@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
   std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
       const override {
     std::vector<std::unique_ptr<framework::OpDescBind>> ops;
-    ops.resize(2);
-
-    ops[0].reset(new framework::OpDescBind());
-    ops[0]->SetType("scale");
-    ops[0]->SetInput("X", OutputGrad("Out"));
-    ops[0]->SetOutput("Out", InputGrad("X"));
-    ops[0]->SetAttr("scale", 1.0f);
-
-    ops[1].reset(new framework::OpDescBind());
-    ops[1]->SetType("scale");
-    ops[1]->SetInput("X", OutputGrad("Out"));
-    ops[1]->SetOutput("Out", InputGrad("Y"));
-    ops[1]->SetAttr("scale", -1.0f);
+    auto x_g = InputGrad("X");
+    if (!x_g.empty()) {
+      auto *x_g_op = new framework::OpDescBind();
+      x_g_op->SetType("scale");
+      x_g_op->SetInput("X", OutputGrad("Out"));
+      x_g_op->SetOutput("Out", x_g);
+      x_g_op->SetAttr("scale", 1.0f);
+      ops.emplace_back(x_g_op);
+    }
+
+    auto y_g = InputGrad("Y");
+    if (!y_g.empty()) {
+      auto *y_g_op = new framework::OpDescBind();
+      y_g_op->SetType("scale");
+      y_g_op->SetInput("X", OutputGrad("Out"));
+      y_g_op->SetOutput("Out", y_g);
+      y_g_op->SetAttr("scale", -1.0f);
+      ops.emplace_back(y_g_op);
+    }
+
     return ops;
   }
 };
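[Editor's note] After this hunk, MinusGradMaker emits a scale op only for inputs that still need a gradient, so operator() may legally return zero, one, or two descriptions. The skip-and-ownership pattern in isolation, as a self-contained sketch; Desc and MakeGradOps are stand-ins, not framework names:

    #include <memory>
    #include <vector>

    struct Desc {};  // stand-in for framework::OpDescBind

    // Emit a gradient description only for inputs that actually need one;
    // emplace_back(new Desc()) transfers ownership to the vector exactly once.
    std::vector<std::unique_ptr<Desc>> MakeGradOps(bool need_x, bool need_y) {
      std::vector<std::unique_ptr<Desc>> ops;
      if (need_x) ops.emplace_back(new Desc());
      if (need_y) ops.emplace_back(new Desc());
      return ops;  // possibly empty: callers can no longer assume two entries
    }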
diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
index 9445917739..15aa05f266 100644
--- a/paddle/operators/pad_op.cc
+++ b/paddle/operators/pad_op.cc
@@ -121,6 +121,7 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
     bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     bind->SetAttrMap(Attrs());
+    bind->SetType("pad_grad");
     return std::unique_ptr<framework::OpDescBind>(bind);
   }
 };
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index bc9868874d..70fe429f59 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -14,6 +14,12 @@
 
 #include "paddle/operators/softmax_with_cross_entropy_op.h"
 #include 
+#include <iostream>
+
+#define DBG_LINE()                                   \
+  do {                                               \
+    std::cerr << "Run at " << __LINE__ << std::endl; \
+  } while (false)
 
 namespace paddle {
 namespace operators {
@@ -187,8 +193,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
-                  ops::SoftmaxWithCrossEntropyOpMaker,
-                  ops::SoftmaxWithCrossEntropyOpMaker);
+                  ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
 REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
                   ops::SoftmaxWithCrossEntropyOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,

From 2594a50245ae7bfe91f05d3c162af39b2457fa21 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 4 Oct 2017 15:49:43 -0700
Subject: [PATCH 8/9] Polish code

---
 paddle/framework/backward.cc                      |  1 -
 paddle/framework/backward_test.cc                 | 12 ------------
 paddle/framework/op_info.h                        |  2 --
 paddle/framework/op_registry.cc                   |  1 -
 paddle/operators/mean_op.cc                       |  1 -
 paddle/operators/softmax_with_cross_entropy_op.cc |  5 -----
 6 files changed, 22 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 9193a1593e..3d81dadfc4 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -203,7 +203,6 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     }
   } else {
     std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
-    PADDLE_ENFORCE(grad_op != nullptr);
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 830d0427fa..a9b71cd809 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -171,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 
-// TEST(Backward, simple_op_grad) {
-//   auto fwd = f::OpRegistry::CreateOp(
-//       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
-//   ASSERT_NE(fwd, nullptr);
-//   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-//   ASSERT_EQ(1UL, gop->Inputs().size());
-//   ASSERT_EQ("rowwise_add_grad", gop->Type());
-//   ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
-//   ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
-//}
-
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
@@ -390,7 +379,6 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                 + 1UL /* external output number*/
                 + 1UL /* number of gradient of external output*/
                 + 2U /* internal variable number*/);
-
   EXPECT_EQ(grad_fc.Outputs(all).size(),
             2UL /* input number of mul*/
                 + 2UL /* input number of rowwise_add
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 968f587b46..231f212fa3 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -23,8 +23,6 @@
 #include "paddle/framework/type_defs.h"
 #include "paddle/platform/macros.h"
 
-#include "glog/logging.h"
-
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc
index ac6aa8d28e..4dc83ec8fe 100644
--- a/paddle/framework/op_registry.cc
+++ b/paddle/framework/op_registry.cc
@@ -55,7 +55,6 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
 }
 
 std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
-  op_desc->Sync();
   return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(),
                   op_desc->GetAttrMap());
 }
diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc
index 339c089e87..2332c9546b 100644
--- a/paddle/operators/mean_op.cc
+++ b/paddle/operators/mean_op.cc
@@ -71,7 +71,6 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-
 REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
 REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean,
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
index 70fe429f59..42c1ba6fdf 100644
--- a/paddle/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -16,11 +16,6 @@
 #include 
 #include <iostream>
 
-#define DBG_LINE()                                   \
-  do {                                               \
-    std::cerr << "Run at " << __LINE__ << std::endl; \
-  } while (false)
-
 namespace paddle {
 namespace operators {

From ebbbaee0c3d69b26405875216dbc5e798adab4ba Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Wed, 4 Oct 2017 16:18:46 -0700
Subject: [PATCH 9/9] Follow comments

---
 paddle/framework/backward.cc    | 4 ++--
 paddle/framework/op_registry.cc | 6 +++---
 paddle/framework/op_registry.h  | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 3d81dadfc4..01ef385fd8 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -39,9 +39,9 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
   std::transform(grad_descs.begin(), grad_descs.end(),
                  std::back_inserter(grad_ops),
                  [](const std::unique_ptr<OpDescBind>& grad_desc) {
-                   return OpRegistry::CreateOp(grad_desc.get());
+                   return OpRegistry::CreateOp(*grad_desc);
                  });
-  PADDLE_ENFORCE_GT(grad_ops.size(), 0);
+  PADDLE_ENFORCE(!grad_ops.empty());
   if (grad_ops.size() == 1) {
     return std::move(grad_ops[0]);
   } else {
diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc
index 4dc83ec8fe..e9d2e55872 100644
--- a/paddle/framework/op_registry.cc
+++ b/paddle/framework/op_registry.cc
@@ -54,9 +54,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
   return CreateOp(op_desc.type(), inputs, outputs, attrs);
 }
 
-std::unique_ptr<OperatorBase> OpRegistry::CreateOp(OpDescBind* op_desc) {
-  return CreateOp(op_desc->Type(), op_desc->Inputs(), op_desc->Outputs(),
-                  op_desc->GetAttrMap());
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
+  return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
+                  op_desc.GetAttrMap());
 }
 
 }  // namespace framework
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index a4f0144ce8..4bf521c48d 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -79,7 +79,7 @@ class OpRegistry {
 
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
 
-  static std::unique_ptr<OperatorBase> CreateOp(OpDescBind* op_desc);
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
 
 template