|
|
|
@ -152,8 +152,8 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd =
|
|
|
|
|
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
@ -175,7 +175,8 @@ TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
|
|
|
|
|
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
|
|
|
|
|
{"out", "tmp_forward"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
@ -194,9 +195,10 @@ TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
f::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
|
|
|
|
|
net.AddOp(
|
|
|
|
|
f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
|
|
|
|
|
{"hidden0", "tmp0"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
|
|
|
|
|
{"hidden1", "tmp1"}, {}));
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
|
|
|
|
|
ASSERT_TRUE(bwd->IsNetOp());
|
|
|
|
|