|
|
|
@ -75,13 +75,13 @@ class FcOp : public operators::NetOp {
|
|
|
|
|
FcOp(const std::string &type, const VarNameMap &inputs,
|
|
|
|
|
const VarNameMap &outputs, const AttributeMap &attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
AddOp(OpRegistry::CreateOp("mul",
|
|
|
|
|
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
|
|
|
|
|
{{"Out", {Output("mul_result")}}}, {}));
|
|
|
|
|
AppendOp(OpRegistry::CreateOp("mul",
|
|
|
|
|
{{"X", {Input("X")}}, {"Y", {Input("W")}}},
|
|
|
|
|
{{"Out", {Output("mul_result")}}}, {}));
|
|
|
|
|
auto input_b = Inputs("b");
|
|
|
|
|
std::string before_act = "mul_result";
|
|
|
|
|
if (input_b.size() != 0) {
|
|
|
|
|
AddOp(OpRegistry::CreateOp(
|
|
|
|
|
AppendOp(OpRegistry::CreateOp(
|
|
|
|
|
"rowwise_add", {{"X", {Output("mul_result")}}, {"b", {input_b[0]}}},
|
|
|
|
|
{{"Out", {Output("add_result")}}}, {}));
|
|
|
|
|
before_act = "add_result";
|
|
|
|
@ -92,8 +92,8 @@ class FcOp : public operators::NetOp {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AddOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
|
|
|
|
|
{{"Out", {Output("Out")}}}, {}));
|
|
|
|
|
AppendOp(OpRegistry::CreateOp("sigmoid", {{"X", {Output(before_act)}}},
|
|
|
|
|
{{"Out", {Output("Out")}}}, {}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -234,13 +234,13 @@ TEST(Backward, net_fc_backward_not_have_b) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp(
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {{"X", {"x"}}, {"W", {"W1"}}, {"b", {"b1"}}},
|
|
|
|
|
{{"mul_result", {"mul_tmp_0"}},
|
|
|
|
|
{"add_result", {"add_tmp_0"}},
|
|
|
|
|
{"Out", {"hidden0"}}},
|
|
|
|
|
{}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp(
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {{"X", {"hidden0"}}, {"W", {"W2"}}, {"b", {"b2"}}},
|
|
|
|
|
{{"mul_result", {"mul_tmp_1"}},
|
|
|
|
|
{"add_result", {"add_tmp_1"}},
|
|
|
|
@ -273,10 +273,10 @@ TEST(Backward, net_input_of_network_not_need_grad) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, net_shared_weight) {
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
|
|
|
|
|
{{"Out", {"out"}}}, {}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
|
|
|
|
|
{{"Out", {"FinalOut"}}}, {}));
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"x"}}, {"Y", {"w"}}},
|
|
|
|
|
{{"Out", {"out"}}}, {}));
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp("mul", {{"X", {"out"}}, {"Y", {"w"}}},
|
|
|
|
|
{{"Out", {"FinalOut"}}}, {}));
|
|
|
|
|
net.CompleteAddOp();
|
|
|
|
|
|
|
|
|
|
auto bwd = f::Backward(net, {});
|
|
|
|
@ -357,19 +357,19 @@ TEST(Backward, op_part_of_input_are_not_need) {
|
|
|
|
|
|
|
|
|
|
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
|
|
|
|
|
ops::NetOp net;
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp(
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {{"X", {"x1"}}, {"W", {"w1"}}, {"b", {"b1"}}},
|
|
|
|
|
{{"mul_result", {"mul_out1"}},
|
|
|
|
|
{"add_result", {"add_out1"}},
|
|
|
|
|
{"Out", {"out1"}}},
|
|
|
|
|
{}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp(
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {{"X", {"out1"}}, {"W", {"w2"}}, {"b", {"b2"}}},
|
|
|
|
|
{{"mul_result", {"mul_out2"}},
|
|
|
|
|
{"add_result", {"tmp_out2"}},
|
|
|
|
|
{"Out", {"out2"}}},
|
|
|
|
|
{}));
|
|
|
|
|
net.AddOp(f::OpRegistry::CreateOp(
|
|
|
|
|
net.AppendOp(f::OpRegistry::CreateOp(
|
|
|
|
|
"fc", {{"X", {"out2"}}, {"W", {"w3"}}, {"b", {"b3"}}},
|
|
|
|
|
{{"mul_result", {"mul_out3"}},
|
|
|
|
|
{"add_result", {"tmp_out3"}},
|
|
|
|
|