@@ -108,6 +108,16 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
     AddComment("");
   }
 };
 
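+// Stub `add` op: its input X is marked as multiple, so it can take any
+// number of operands; backward uses it to accumulate duplicated gradients
+// (see the net_shared_weight test below).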
+class AddOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "x").SetMultiple();
+    AddOutput("Y", "y");
+    AddComment("");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
@@ -123,12 +133,14 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
 REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
 REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
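+// Register the stub `add` op and its grad op; the backward tests below
+// rely on them.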
+REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
+REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
 
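+// A plain op's grad op: its type is "<fwd>_grad", its inputs are the
+// forward outputs' gradients, and its outputs are the forward inputs'
+// gradients.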
-TEST(Backward, simple_grad) {
+TEST(Backward, simple_op_grad) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   ASSERT_NE(fwd, nullptr);
   auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1, gop->inputs_.size());
+  ASSERT_EQ(1UL, gop->inputs_.size());
   ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
   ASSERT_EQ("rowwise_add_grad", gop->type_);
   ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
@@ -139,7 +151,7 @@ TEST(Backward, simple_grad) {
   // LOG(INFO) << gop->Output("X" + "@GRAD");
 }
 
-TEST(Backward, fc_backward_normal) {
+TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd =
       f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
   ASSERT_NE(fwd, nullptr);
@@ -161,7 +173,7 @@ TEST(Backward, fc_backward_normal) {
   ASSERT_EQ("mul_grad", d_mul.type_);
 }
 
-TEST(Backward, fc_backward_not_have_b) {
+TEST(Backward, net_fc_backward_not_have_b) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
   ASSERT_NE(fwd, nullptr);
@@ -180,12 +192,12 @@ TEST(Backward, fc_backward_not_have_b) {
   ASSERT_EQ("mul_grad", d_mul.type_);
 }
 
-TEST(Backward, input_layer_not_need_grad) {
+TEST(Backward, net_input_of_network_not_need_grad) {
   f::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
   net.AddOp(
       f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));
 
   net.CompleteAddOp();
   auto bwd = Backward(net, {"X"});  // X@GRAD is not needed.
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
@@ -198,16 +210,40 @@ TEST(Backward, input_layer_not_need_grad) {
     ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
               all_output.end());
   }
 
+  // X@GRAD is not generated, because X is the input of the whole network.
+  ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            all_output.end());
+
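+  // Two grad ops in the backward net; the grad of the first fc is itself a
+  // net whose last op emits an empty X@GRAD, since X needs no gradient.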
+  ASSERT_EQ(2UL, bwd_net->ops_.size());
+  ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
+  auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
+  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
+  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
+            first_fc_grad->ops_[2]->Output(
+                "X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
+
+TEST(Backward, net_shared_weight) {
+  f::NetOp net;
+  net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
+  net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
+  net.CompleteAddOp();
+
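+  // W feeds both mul ops, so its two partial gradients must be accumulated;
+  // Backward is expected to append an extra op that sums them.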
+  auto bwd = f::Backward(net, {});
+  ASSERT_TRUE(bwd->IsNetOp());
+  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
+  ASSERT_EQ(3UL, bwd_net->ops_.size());
+  ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
+}
+
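+// fc expands into a network of ops; CreateGradOp only works on plain ops,
+// so asking for fc's grad op directly must throw.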
-TEST(Backward, not_for_network) {
+TEST(Backward, op_register_grad_not_for_network) {
   auto fwd =
       f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
                               {{"temporary_index", std::vector<int>{1}}});
   ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
 }
 
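+// When none of the inputs needs a gradient, Backward degenerates to an
+// empty net.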
-TEST(Backward, all_input_are_not_need) {
+TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"X", "b"});
   ASSERT_TRUE(backward->IsNetOp());
@@ -215,7 +251,7 @@ TEST(Backward, all_input_are_not_need) {
   ASSERT_TRUE(net->ops_.empty());
 }
 
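+// Symmetric case: when no output gradient is needed, the backward net is
+// empty as well.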
-TEST(Backward, all_output_are_not_need) {
+TEST(Backward, op_all_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"Out"});
   ASSERT_TRUE(backward->IsNetOp());
@@ -223,7 +259,7 @@ TEST(Backward, all_output_are_not_need) {
   ASSERT_TRUE(net->ops_.empty());
 }
 
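+// Z@GRAD is not provided, so backward should feed Z's grad slot via the
+// fill_zeros_like op registered above while Y's path still runs.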
-TEST(Backward, part_of_output_are_not_need) {
+TEST(Backward, op_part_of_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
   auto backward = f::Backward(*fwd, {"Z"});
   ASSERT_TRUE(backward->IsNetOp());
@@ -248,7 +284,7 @@ TEST(Backward, part_of_output_are_not_need) {
             d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
 }
 
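+// a@GRAD is not needed; the grad op should leave it empty while b's
+// gradient is still produced.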
-TEST(Backward, part_of_input_are_not_need) {
+TEST(Backward, op_part_of_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
   auto backward = f::Backward(*fwd, {"a"});
   ASSERT_TRUE(backward->IsNetOp());