|
|
@ -313,6 +313,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
|
|
|
|
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
|
|
|
|
auto backward = f::Backward(*fwd, {"a"});
|
|
|
|
auto backward = f::Backward(*fwd, {"a"});
|
|
|
@ -334,6 +335,7 @@ TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
ASSERT_EQ(grad_mul.Input("B"), "b");
|
|
|
|
ASSERT_EQ(grad_mul.Input("B"), "b");
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out"), "out");
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out"), "out");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
f::NetOp net;
|
|
|
|
f::NetOp net;
|
|
|
@ -343,33 +345,35 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
{"mul_out2", "tmp_out2", "out2"}, {}));
|
|
|
|
{"mul_out2", "tmp_out2", "out2"}, {}));
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
|
|
|
|
{"mul_out3", "tmp_out3", "out3"}, {}));
|
|
|
|
{"mul_out3", "tmp_out3", "out3"}, {}));
|
|
|
|
net.CompleteAddOp(false);
|
|
|
|
net.CompleteAddOp();
|
|
|
|
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
|
|
|
|
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
|
|
|
|
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
|
|
|
|
|
|
|
|
EXPECT_EQ(bwd_net->ops_[0]->type_, "fc_grad");
|
|
|
|
|
|
|
|
EXPECT_EQ(bwd_net->ops_[1]->type_, "");
|
|
|
|
|
|
|
|
EXPECT_EQ(bwd_net->ops_[2]->type_, "");
|
|
|
|
|
|
|
|
|
|
|
|
auto &grad_fc = *bwd_net->ops_[0];
|
|
|
|
auto &grad_fc = *bwd_net->ops_[0];
|
|
|
|
ASSERT_EQ(grad_fc.type_, "fc_grad");
|
|
|
|
EXPECT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
|
|
|
|
ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 3UL + 3UL);
|
|
|
|
EXPECT_EQ(grad_fc.outputs_.size(), 3UL);
|
|
|
|
ASSERT_EQ(grad_fc.outputs_.size(), 3UL);
|
|
|
|
EXPECT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
EXPECT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
EXPECT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
ASSERT_EQ(grad_fc.Input("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
EXPECT_EQ(grad_fc.Input("mul_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"mul_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
ASSERT_EQ(grad_fc.Input("add_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
EXPECT_EQ(grad_fc.Input("add_result" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"tmp_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"tmp_out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
EXPECT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(grad_fc.Input("X"), "out2");
|
|
|
|
EXPECT_EQ(grad_fc.Input("X"), "out2");
|
|
|
|
ASSERT_EQ(grad_fc.Input("W"), "w3");
|
|
|
|
EXPECT_EQ(grad_fc.Input("W"), "w3");
|
|
|
|
ASSERT_EQ(grad_fc.Input("b"), "b3");
|
|
|
|
EXPECT_EQ(grad_fc.Input("b"), "b3");
|
|
|
|
ASSERT_EQ(grad_fc.Input("mul_result"), "mul_out3");
|
|
|
|
EXPECT_EQ(grad_fc.Input("mul_result"), "mul_out3");
|
|
|
|
ASSERT_EQ(grad_fc.Input("add_result"), "tmp_out3");
|
|
|
|
EXPECT_EQ(grad_fc.Input("add_result"), "tmp_out3");
|
|
|
|
ASSERT_EQ(grad_fc.Input("Out"), "out3");
|
|
|
|
EXPECT_EQ(grad_fc.Input("Out"), "out3");
|
|
|
|
}
|
|
|
|
}
|
|
|
|