|
|
|
@ -66,22 +66,25 @@ class FCOp : public NetOp {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
|
|
|
|
|
auto sum_out = mul_out[0];
|
|
|
|
|
if (n > 1) {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"sum", {{"X", {mul_out}}}, {{"Out", {Output("SumOut")}}}, {}));
|
|
|
|
|
sum_out = Output("SumOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
|
|
|
|
|
{{"Out", {sum_out}}}, {}));
|
|
|
|
|
} else {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"identity", {{"X", {mul_out[0]}}}, {{"Y", {Output("SumOut")}}}, {}));
|
|
|
|
|
if (Output("SumOut") != framework::kEmptyVarName) {
|
|
|
|
|
this->Rename(Output("SumOut"), framework::kEmptyVarName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// add_out = sum_out + b
|
|
|
|
|
auto b = Input("B");
|
|
|
|
|
std::string add_out = "SumOut";
|
|
|
|
|
auto add_out = sum_out;
|
|
|
|
|
if (b != framework::kEmptyVarName) {
|
|
|
|
|
add_out = "AddOut";
|
|
|
|
|
add_out = Output("AddOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"rowwise_add", {{"X", {Output("SumOut")}}, {"b", {Input("B")}}},
|
|
|
|
|
{{"Out", {Output(add_out)}}}, {}));
|
|
|
|
|
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
|
|
|
|
|
{{"Out", {add_out}}}, {}));
|
|
|
|
|
} else {
|
|
|
|
|
if (Output("AddOut") != framework::kEmptyVarName) {
|
|
|
|
|
this->Rename(Output("AddOut"), framework::kEmptyVarName);
|
|
|
|
@ -89,8 +92,8 @@ class FCOp : public NetOp {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto activation = Attr<std::string>("activation");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
activation, {{"X", {Output(add_out)}}}, {{"Y", {Output("Out")}}}, {}));
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(activation, {{"X", {add_out}}},
|
|
|
|
|
{{"Y", {Output("Out")}}}, {}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|