|
|
|
@ -63,14 +63,22 @@ class FcOp : public NetOp {
|
|
|
|
|
public:
|
|
|
|
|
void Init() override {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
|
|
|
|
|
{Output("before_act")}, {}));
|
|
|
|
|
{Output("mul_result")}, {}));
|
|
|
|
|
auto b_name = Input("b");
|
|
|
|
|
std::string before_act = "mul_result";
|
|
|
|
|
if (b_name != EMPTY_VAR_NAME()) {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
|
|
|
|
|
{Output("before_act")}, {}));
|
|
|
|
|
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
|
|
|
|
|
{Output("add_result")}, {}));
|
|
|
|
|
before_act = "add_result";
|
|
|
|
|
} else {
|
|
|
|
|
auto out_varname = Output("add_result");
|
|
|
|
|
if (out_varname != EMPTY_VAR_NAME()) {
|
|
|
|
|
this->Rename(out_varname, EMPTY_VAR_NAME());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
|
|
|
|
|
{Output("Out")}, {}));
|
|
|
|
|
|
|
|
|
|
AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
|
|
|
|
|
{}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -82,7 +90,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X", "x");
|
|
|
|
|
AddInput("W", "w");
|
|
|
|
|
AddInput("b", "b");
|
|
|
|
|
AddOutput("before_act", "before act").SetTemporary();
|
|
|
|
|
AddOutput("mul_result", "").SetTemporary();
|
|
|
|
|
AddOutput("add_result", "").SetTemporary();
|
|
|
|
|
AddOutput("Out", "");
|
|
|
|
|
AddComment("");
|
|
|
|
|
}
|
|
|
|
@ -153,7 +162,7 @@ TEST(Backward, simple_op_grad) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
|
|
|
|
|
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
@ -176,7 +185,7 @@ TEST(Backward, net_fc_backward_normal) {
|
|
|
|
|
TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
|
|
|
|
|
{"out", "tmp_forward"}, {});
|
|
|
|
|
{"mul_result", "add_result", "tmp"}, {});
|
|
|
|
|
ASSERT_NE(fwd, nullptr);
|
|
|
|
|
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
|
|
|
|
|
ASSERT_TRUE(gop->IsNetOp());
|
|
|
|
@ -196,9 +205,9 @@ TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
f::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
|
|
|
|
|
{"hidden0", "tmp0"}, {}));
|
|
|
|
|
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
|
|
|
|
|
{"hidden1", "tmp1"}, {}));
|
|
|
|
|
{"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
|
|
|
|
|
ASSERT_TRUE(bwd->IsNetOp());
|
|
|
|
@ -235,6 +244,7 @@ TEST(Backward, net_shared_weight) {
|
|
|
|
|
ASSERT_TRUE(bwd->IsNetOp());
|
|
|
|
|
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
|
|
|
|
|
ASSERT_EQ(3UL, bwd_net->ops_.size());
|
|
|
|
|
LOG(INFO) << bwd_net->DebugString();
|
|
|
|
|
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|