|
|
|
@ -24,30 +24,30 @@ class FCOp : public NetOp {
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
// mul_out = X * W
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}},
|
|
|
|
|
{{"Out", {Output("mul_out")}}}, {}));
|
|
|
|
|
|
|
|
|
|
std::string add_out_name = "mul_out";
|
|
|
|
|
auto b = Input("b");
|
|
|
|
|
if (b != framework::kEmptyVarName) {
|
|
|
|
|
// add_out = mul_out + b
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"rowwise_add", {{"X", {Output("mul_out")}}, {"b", {Input("b")}}},
|
|
|
|
|
{{"Out", {Output("add_out")}}}, {}));
|
|
|
|
|
add_out_name = "add_out";
|
|
|
|
|
} else {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"identity", {{"X", {Output("mul_out")}}},
|
|
|
|
|
{{"Out", {Output("add_out")}}}, {}));
|
|
|
|
|
auto add_out = Output("add_out");
|
|
|
|
|
if (add_out != framework::kEmptyVarName) {
|
|
|
|
|
this->Rename(add_out, framework::kEmptyVarName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto activation = GetAttr<std::string>("activation");
|
|
|
|
|
if (activation == "identity") {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(activation,
|
|
|
|
|
{{"X", {Output("add_out")}}},
|
|
|
|
|
{{"Out", {Output("Out")}}}, {}));
|
|
|
|
|
} else {
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(activation,
|
|
|
|
|
{{"X", {Output("add_out")}}},
|
|
|
|
|
{{"Y", {Output("Out")}}}, {}));
|
|
|
|
|
}
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(activation,
|
|
|
|
|
{{"X", {Output(add_out_name)}}},
|
|
|
|
|
{{"Y", {Output("Out")}}}, {}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|