|
|
|
@ -149,7 +149,6 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
auto gop = f::OpRegistry::CreateGradOp(*fwd);
|
|
|
|
|
LOG(INFO) << gop->DebugString();
|
|
|
|
|
ASSERT_EQ(1UL, gop->inputs_.size());
|
|
|
|
|
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
|
|
|
|
|
ASSERT_EQ("rowwise_add_grad", gop->type_);
|
|
|
|
@ -161,18 +160,19 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
// LOG(INFO) << gop->Output("X" + "@GRAD");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, simple_net_grad) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
TEST(Backward, simple_op_not_need_grad) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"x", "b"}, {"out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
auto gop = f::Backward(*fwd, {});
|
|
|
|
|
auto gop = f::Backward(*fwd, {"x"});
|
|
|
|
|
LOG(INFO) << gop->DebugString();
|
|
|
|
|
ASSERT_NE(gop->outputs_.find("x" + f::OperatorBase::GRAD_VAR_SUFFIX()),
|
|
|
|
|
gop->outputs_.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
LOG(INFO) << fwd->DebugString();
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(gop.get());
|
|
|
|
|