|
|
|
@ -63,10 +63,10 @@ class FcOp : public NetOp {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
|
|
|
|
|
{Output("before_act")}, {}));
|
|
|
|
|
{Output("mul_out")}, {}));
|
|
|
|
|
auto b_name = Input("b");
|
|
|
|
|
if (b_name != EMPTY_VAR_NAME()) {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
|
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), b_name},
|
|
|
|
|
{Output("before_act")}, {}));
|
|
|
|
|
}
|
|
|
|
|
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
|
|
|
|
@ -82,6 +82,7 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "x");
|
|
|
|
|
AddInput("W", "w");
|
|
|
|
|
AddInput("b", "b");
|
|
|
|
|
AddOutput("mul_out", "mul output").SetTemporary();
|
|
|
|
|
AddOutput("before_act", "before act").SetTemporary();
|
|
|
|
|
AddOutput("Out", "");
|
|
|
|
|
AddComment("");
|
|
|
|
@ -140,6 +141,7 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
auto gop = f::OpRegistry::CreateGradOp(*fwd);
|
|
|
|
|
LOG(INFO) << gop->DebugString();
|
|
|
|
|
ASSERT_EQ(1UL, gop->inputs_.size());
|
|
|
|
|
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
|
|
|
|
|
ASSERT_EQ("rowwise_add_grad", gop->type_);
|
|
|
|
@ -151,10 +153,18 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
// LOG(INFO) << gop->Output("X" + "@GRAD");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, simple_net_grad) {
|
|
|
|
|
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
auto gop = f::Backward(*fwd, {});
|
|
|
|
|
LOG(INFO) << gop->DebugString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd =
|
|
|
|
|
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
LOG(INFO) << fwd->DebugString();
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
|
auto net = static_cast<f::NetOp *>(gop.get());
|
|
|
|
|