|
|
|
@ -133,6 +133,7 @@ TEST(GradOpDescBuilder, MutiInOut) {
|
|
|
|
|
f::OpDescBind *grad_op = new f::OpDescBind();
|
|
|
|
|
f::CompleteGradOpDesc(forw_op, grad_op);
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_op->Type(), "mult_io_grad");
|
|
|
|
|
ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
|
|
|
|
|
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
|
|
|
|
|
EXPECT_EQ(grad_op->Input("In2_mult"),
|
|
|
|
@ -156,4 +157,45 @@ TEST(GradOpDescBuilder, MutiInOut) {
|
|
|
|
|
f::GradVarName("in2_3")}));
|
|
|
|
|
EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
|
|
|
|
|
std::vector<std::string>({f::GradVarName("in3")}));
|
|
|
|
|
delete forw_op;
|
|
|
|
|
delete grad_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(GradOpDescBuilder, IOIgnoredInGradient) {
|
|
|
|
|
f::OpDescBind *forw_op = new f::OpDescBind();
|
|
|
|
|
forw_op->SetType("io_ignored");
|
|
|
|
|
forw_op->SetInput("In1", {"in1"});
|
|
|
|
|
forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
|
|
|
|
|
forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
|
|
|
|
|
forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
|
|
|
|
|
forw_op->SetOutput("Out2", {"out2"});
|
|
|
|
|
|
|
|
|
|
f::OpDescBind *grad_op = new f::OpDescBind();
|
|
|
|
|
f::CompleteGradOpDesc(forw_op, grad_op);
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
|
|
|
|
|
// 'In2' and 'Out2' are ignored in gradient calculating
|
|
|
|
|
ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
|
|
|
|
|
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
|
|
|
|
|
EXPECT_EQ(grad_op->Input("In3_mult"),
|
|
|
|
|
std::vector<std::string>({"in3_1", "in3_2"}));
|
|
|
|
|
EXPECT_EQ(grad_op->Input("Out1_mult"),
|
|
|
|
|
std::vector<std::string>({"out1_1", "out1_2"}));
|
|
|
|
|
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
|
|
|
|
|
std::vector<std::string>(
|
|
|
|
|
{f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
|
|
|
|
|
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
|
|
|
|
|
std::vector<std::string>({f::GradVarName("out2")}));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
|
|
|
|
|
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
|
|
|
|
|
std::vector<std::string>({f::GradVarName("in1")}));
|
|
|
|
|
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
|
|
|
|
|
std::vector<std::string>(
|
|
|
|
|
{f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
|
|
|
|
|
EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
|
|
|
|
|
std::vector<std::string>(
|
|
|
|
|
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
|
|
|
|
|
delete forw_op;
|
|
|
|
|
delete grad_op;
|
|
|
|
|
}
|