Merge branch 'feature/grad_reg_mechanism_cont2' of https://github.com/reyoung/Paddle into dev_backward_for_op_desc_dev
commit a598ef5388

paddle/framework/grad_op_builder.cc
@@ -1,97 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */

#include "paddle/framework/grad_op_builder.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
enum class OpArgType { IN, OUT };
|
||||
|
||||
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
|
||||
bool is_grad, VariableNameMap* vars) {
|
||||
const auto& src_inout =
|
||||
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
|
||||
auto& dst_inout = *vars;
|
||||
auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
|
||||
const auto& src_arg_list =
|
||||
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
|
||||
for (const auto& arg : src_arg_list) {
|
||||
if (arg.not_in_gradient() && !is_grad) continue;
|
||||
const std::string src_name = arg.name();
|
||||
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
|
||||
dst_inout[dst_name].reserve(src_inout.at(src_name).size());
|
||||
for (auto& var_name : src_inout.at(src_name)) {
|
||||
std::string s = is_grad ? GradVarName(var_name) : var_name;
|
||||
dst_inout[dst_name].emplace_back(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OperatorBase* BuildGradOp(const OperatorBase* op) {
|
||||
auto& info = OpInfoMap::Instance().Get(op->Type());
|
||||
PADDLE_ENFORCE(info.HasGradientOp());
|
||||
|
||||
VariableNameMap inputs;
|
||||
VariableNameMap outputs;
|
||||
TransOpArg(op, OpArgType::IN, false, &inputs); // I
|
||||
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
|
||||
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
|
||||
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
|
||||
|
||||
auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
|
||||
return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
|
||||
}

// Same mapping as TransOpArg, but operating on compile-time op descriptions
// (OpDescBind) instead of constructed operators.
static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
                           bool is_grad, OpDescBind* dst_op,
                           const OpArgType& dst_type) {
  PADDLE_ENFORCE(dst_op != nullptr,
                 "Protobuf desc of gradient op must be initialized first.");
  const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
  const auto& src_arg_list =
      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
  for (const auto& arg : src_arg_list) {
    if (arg.not_in_gradient() && !is_grad) continue;
    const std::string src_name = arg.name();
    std::vector<std::string> vars = src_type == OpArgType::IN
                                        ? src_op->Input(src_name)
                                        : src_op->Output(src_name);
    if (is_grad) {
      for (std::string& var : vars) {
        var = GradVarName(var);
      }
    }
    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
    dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
                              : dst_op->SetOutput(dst_name, vars);
  }
}

void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
  auto& info = OpInfoMap::Instance().Get(forw_op->Type());
  PADDLE_ENFORCE(info.HasGradientOp());

  grad_op->SetType(info.grad_op_type_);

  TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);   // I
  TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);  // O
  TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);   // OG
  TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);   // IG

  grad_op->SetAttrMap(forw_op->GetAttrMap());
}

}  // namespace framework
}  // namespace paddle
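To make the I/O/OG/IG naming rule above concrete, here is a minimal standalone sketch of the mapping the builders apply for a forward op Out = f(X). It assumes GradVarName appends the framework's "@GRAD" suffix; the slot and variable names are illustrative, not part of the patch.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for paddle::framework::GradVarName (assumed "@GRAD" suffix).
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

int main() {
  std::map<std::string, std::vector<std::string>> grad_inputs, grad_outputs;
  grad_inputs["X"] = {"x"};                                // I : forward input
  grad_inputs["Out"] = {"out"};                            // O : forward output
  grad_inputs[GradVarName("Out")] = {GradVarName("out")};  // OG: output grad
  grad_outputs[GradVarName("X")] = {GradVarName("x")};     // IG: input grad

  for (const auto& kv : grad_inputs) {
    std::cout << "grad op input " << kv.first << ": " << kv.second[0] << "\n";
  }
  return 0;
}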
paddle/framework/grad_op_builder.h
@@ -1,28 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

OperatorBase* BuildGradOp(const OperatorBase* op);

void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);

}  // namespace framework
}  // namespace paddle
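A hedged call-site sketch for the two entry points declared above. `fwd` and `forw_desc` are placeholder objects of an op type assumed to be registered with a gradient op (via REGISTER_OP); neither appears in the original patch.

// Runtime path: build a gradient OperatorBase from a constructed forward op.
paddle::framework::OperatorBase* grad = paddle::framework::BuildGradOp(&fwd);
// ... run or inspect *grad ...
delete grad;  // BuildGradOp hands back a raw pointer the caller owns.

// Compile-time path: fill an empty OpDescBind from the forward op's desc.
paddle::framework::OpDescBind grad_desc;
paddle::framework::CompleteGradOpDesc(&forw_desc, &grad_desc);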
paddle/framework/grad_op_builder_test.cc
@@ -1,186 +0,0 @@
#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

USE_OP(sum);

namespace paddle {
namespace framework {

class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
 public:
  MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable();
    AddInput("In3", "another single input");
    AddOutput("Out1", "a single output");
    AddOutput("Out2_mult", "a multiple output").AsDuplicable();
    AddComment("test op with multiple inputs and outputs");
  }
};

class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
 public:
  IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
    AddInput("In3_mult", "another multiple input").AsDuplicable();
    AddOutput("Out1_mult", "a multiple output").AsDuplicable();
    AddOutput("Out2", "a single output").NotInGradient();
    AddComment("op with inputs and outputs ignored when computing the gradient");
  }
};

}  // namespace framework
}  // namespace paddle

namespace f = paddle::framework;

REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);

TEST(GradOpBuilder, MutiInOut) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "mult_io", {{"In1", {"in1"}},
                  {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
                  {"In3", {"in3"}}},
      {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_test_op->Input("In3"), "in3");
  EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
  EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
            std::vector<std::string>({"out2_1", "out2_2"}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
            f::GradVarName("out1"));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
            std::vector<std::string>(
                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>({f::GradVarName("in2_1"),
                                      f::GradVarName("in2_2"),
                                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3"));
}

TEST(GradOpBuilder, IOIgnoredInGradient) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "io_ignored", {{"In1", {"in1"}},
                     {"In2_mult", {"in2_1", "in2_2"}},
                     {"In3_mult", {"in3_1", "in3_2"}}},
      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // 'In2_mult' and 'Out2' are ignored when computing the gradient.
  ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
            f::GradVarName("out2"));

  ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}

TEST(GradOpDescBuilder, MutiInOut) {
  f::OpDescBind *forw_op = new f::OpDescBind();
  forw_op->SetType("mult_io");
  forw_op->SetInput("In1", {"in1"});
  forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
  forw_op->SetInput("In3", {"in3"});
  forw_op->SetOutput("Out1", {"out1"});
  forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"});

  f::OpDescBind *grad_op = new f::OpDescBind();
  f::CompleteGradOpDesc(forw_op, grad_op);

  EXPECT_EQ(grad_op->Type(), "mult_io_grad");
  ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
  EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
  EXPECT_EQ(grad_op->Input("In2_mult"),
            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_op->Input("In3"), std::vector<std::string>({"in3"}));
  EXPECT_EQ(grad_op->Input("Out1"), std::vector<std::string>({"out1"}));
  EXPECT_EQ(grad_op->Input("Out2_mult"),
            std::vector<std::string>({"out2_1", "out2_2"}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")),
            std::vector<std::string>({f::GradVarName("out1")}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")),
            std::vector<std::string>(
                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
  EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
            std::vector<std::string>({f::GradVarName("in1")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
            std::vector<std::string>({f::GradVarName("in2_1"),
                                      f::GradVarName("in2_2"),
                                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
            std::vector<std::string>({f::GradVarName("in3")}));
  delete forw_op;
  delete grad_op;
}

TEST(GradOpDescBuilder, IOIgnoredInGradient) {
  f::OpDescBind *forw_op = new f::OpDescBind();
  forw_op->SetType("io_ignored");
  forw_op->SetInput("In1", {"in1"});
  forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
  forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
  forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
  forw_op->SetOutput("Out2", {"out2"});

  f::OpDescBind *grad_op = new f::OpDescBind();
  f::CompleteGradOpDesc(forw_op, grad_op);

  EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
  // 'In2_mult' and 'Out2' are ignored when computing the gradient.
  ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
  EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
  EXPECT_EQ(grad_op->Input("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_op->Input("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
            std::vector<std::string>({f::GradVarName("out2")}));

  ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
  EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
            std::vector<std::string>({f::GradVarName("in1")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
  delete forw_op;
  delete grad_op;
}
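The two IOIgnoredInGradient tests pin down the NotInGradient semantics: a slot marked NotInGradient is dropped only when its forward value would be copied, while its @GRAD counterpart is still wired through (the `arg.not_in_gradient() && !is_grad` check in both builders). A minimal standalone sketch of that filter, with illustrative slot names:

#include <iostream>
#include <string>
#include <vector>

struct Slot {
  std::string name;
  bool not_in_gradient;
};

int main() {
  // Mirrors IOIgnoredOpMaker's outputs: Out1_mult kept, Out2 opted out.
  const std::vector<Slot> outputs{{"Out1_mult", false}, {"Out2", true}};
  for (const auto& s : outputs) {
    // O slot: copied only if the argument did not opt out of the gradient.
    if (!s.not_in_gradient) std::cout << "grad input: " << s.name << "\n";
    // OG slot: always forwarded to the gradient op.
    std::cout << "grad input: " << s.name << "@GRAD\n";
  }
  return 0;
}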