|
|
|
@ -166,7 +166,7 @@ TEST(Backward, part_of_output_are_not_need) {
|
|
|
|
|
auto backward = f::Backward(*fwd, {"Z"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_EQ(net->ops_.size(), 2);
|
|
|
|
|
ASSERT_EQ(net->ops_.size(), 2UL);
|
|
|
|
|
|
|
|
|
|
auto &fill_zero = *net->ops_[0];
|
|
|
|
|
ASSERT_EQ("fill_zeros_like", fill_zero.type_);
|
|
|
|
@ -184,4 +184,23 @@ TEST(Backward, part_of_output_are_not_need) {
|
|
|
|
|
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
|
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
|
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, part_of_input_are_not_need) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
|
|
|
|
|
auto backward = f::Backward(*fwd, {"a"});
|
|
|
|
|
ASSERT_TRUE(backward->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(backward.get());
|
|
|
|
|
ASSERT_EQ(net->ops_.size(), 1UL);
|
|
|
|
|
|
|
|
|
|
auto &grad_mul = *net->ops_[0];
|
|
|
|
|
ASSERT_EQ(grad_mul.type_, "mul_grad");
|
|
|
|
|
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
|
|
|
|
|
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
|
|
|
|
|
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
f::OperatorBase::EMPTY_VAR_NAME());
|
|
|
|
|
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
}
|