@@ -128,6 +128,7 @@ TEST(Backward, simple_grad) {
  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
  ASSERT_NE(fwd, nullptr);
  auto gop = f::OpRegistry::CreateGradOp(*fwd);
  ASSERT_EQ(1, gop->inputs_.size());
  ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
  ASSERT_EQ("rowwise_add_grad", gop->type_);
  ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);

@@ -138,6 +139,67 @@ TEST(Backward, simple_grad) {
  // LOG(INFO) << gop->Output("X" + "@GRAD");
}
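
// "fc" expands to mul -> rowwise_add -> sigmoid, so Backward() should return
// a NetOp holding the three gradient ops in reverse order.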
TEST(Backward, fc_backward_normal) {
  std::shared_ptr<f::OperatorBase> fwd =
      f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
  ASSERT_NE(fwd, nullptr);
  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
  ASSERT_TRUE(gop->IsNetOp());
  auto net = static_cast<f::NetOp *>(gop.get());

  ASSERT_NO_THROW(net->DebugString());

  ASSERT_EQ(3UL, net->ops_.size());

  f::OperatorBase &d_sigmoid = *net->ops_[0];
  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);

  f::OperatorBase &d_add = *net->ops_[1];
  ASSERT_EQ("rowwise_add_grad", d_add.type_);

  f::OperatorBase &d_mul = *net->ops_[2];
  ASSERT_EQ("mul_grad", d_mul.type_);
}
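
// With EMPTY_VAR_NAME() passed for the bias, fc has no rowwise_add step, so
// its backward net shrinks to sigmoid_grad and mul_grad.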
TEST(Backward, fc_backward_not_have_b) {
  std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
      "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
  ASSERT_NE(fwd, nullptr);
  std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
  ASSERT_TRUE(gop->IsNetOp());
  auto net = static_cast<f::NetOp *>(gop.get());

  ASSERT_NO_THROW(net->DebugString());

  ASSERT_EQ(2UL, net->ops_.size());

  f::OperatorBase &d_sigmoid = *net->ops_[0];
  ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);

  f::OperatorBase &d_mul = *net->ops_[1];
  ASSERT_EQ("mul_grad", d_mul.type_);
}
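
// Backward()'s second argument lists variables whose gradients are not
// needed; everything else (weights, biases, intermediate activations) must
// still appear among the backward net's @GRAD outputs.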
TEST(Backward, input_layer_not_need_grad) {
  f::NetOp net;
  net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
  net.AddOp(
      f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));

  auto bwd = Backward(net, {"X"});  // X@GRAD is not needed.
  ASSERT_TRUE(bwd->IsNetOp());
  auto bwd_net = static_cast<f::NetOp *>(bwd.get());

  std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
      bwd_net->outputs_.begin(), bwd_net->outputs_.end());
  all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());

  for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
    ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
              all_output.end());
  }
}
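
// Here "fc" is created with a second output, "tmp_out"; judging by the test
// name, Backward is not meant to be applied in this network-style setup.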
TEST(Backward, not_for_network) {
  auto fwd =
      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},

@@ -166,7 +228,7 @@ TEST(Backward, part_of_output_are_not_need) {
  auto backward = f::Backward(*fwd, {"Z"});
  ASSERT_TRUE(backward->IsNetOp());
  auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(net->ops_.size(), 2);
+  ASSERT_EQ(net->ops_.size(), 2UL);

  auto &fill_zero = *net->ops_[0];
  ASSERT_EQ("fill_zeros_like", fill_zero.type_);
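  // "Z" is on the no-grad list, so no Z@GRAD flows in from above; the
  // fill_zeros_like op presumably supplies an all-zero Z@GRAD so the
  // remaining gradient op sees a complete set of inputs.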