|
|
|
@ -24,30 +24,49 @@ class FCOp : public NetOp {
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
// mul_out = X * W
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"mul", {{"X", {Input("X")}}, {"Y", {Input("W")}}},
|
|
|
|
|
{{"Out", {Output("mul_out")}}}, {}));
|
|
|
|
|
auto x = Inputs("X");
|
|
|
|
|
auto w = Inputs("W");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x.size(), w.size(),
|
|
|
|
|
"The size of inputs X(%d) should be the same as that of weights W(%d).",
|
|
|
|
|
x.size(), w.size());
|
|
|
|
|
|
|
|
|
|
int n = x.size();
|
|
|
|
|
PADDLE_ENFORCE_GE(n, 1,
|
|
|
|
|
"The size of inputs X(%d) should be no less than 1.", n);
|
|
|
|
|
|
|
|
|
|
// mul_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
|
|
|
|
|
AppendOp(
|
|
|
|
|
framework::OpRegistry::CreateOp("mul", {{"X", {x[0]}}, {"W", {w[0]}}},
|
|
|
|
|
{{"Out", {Output("mul_out")}}}, {}));
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < n; i++) {
|
|
|
|
|
// mul_out = mul_out + X[i] * W[i]
|
|
|
|
|
AppendOp(
|
|
|
|
|
framework::OpRegistry::CreateOp("mul", {{"X", {x[i]}}, {"Y", {w[i]}}},
|
|
|
|
|
{{"Out", {Output("add_out")}}}, {}));
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"add", {{"X", {Output("mul_out")}}, {"Y", {Output("add_out")}}},
|
|
|
|
|
{{"Out", {Output("mul_out")}}}, {}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string add_out_name = "mul_out";
|
|
|
|
|
auto b = Input("b");
|
|
|
|
|
std::string add_out = "mul_out";
|
|
|
|
|
if (b != framework::kEmptyVarName) {
|
|
|
|
|
// add_out = mul_out + b
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"rowwise_add", {{"X", {Output("mul_out")}}, {"b", {Input("b")}}},
|
|
|
|
|
{{"Out", {Output("add_out")}}}, {}));
|
|
|
|
|
add_out_name = "add_out";
|
|
|
|
|
add_out = "add_out";
|
|
|
|
|
} else {
|
|
|
|
|
auto add_out = Output("add_out");
|
|
|
|
|
if (add_out != framework::kEmptyVarName) {
|
|
|
|
|
this->Rename(add_out, framework::kEmptyVarName);
|
|
|
|
|
if (Output("add_out") != framework::kEmptyVarName) {
|
|
|
|
|
this->Rename(Output("add_out"), framework::kEmptyVarName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto activation = GetAttr<std::string>("activation");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(activation,
|
|
|
|
|
{{"X", {Output(add_out_name)}}},
|
|
|
|
|
{{"Y", {Output("Out")}}}, {}));
|
|
|
|
|
auto activation = Attr<std::string>("activation");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
activation, {{"X", {Output(add_out)}}}, {{"Y", {Output("Y")}}}, {}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -56,11 +75,11 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
FCOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The 2D input matrix of FC operator.");
|
|
|
|
|
AddInput("W", "The 2D weight matrix of FC operator.");
|
|
|
|
|
AddInput("b", "The 1D bias vector of FC operator");
|
|
|
|
|
AddInput("X", "The 2-D input matrix of FC operator.").AsDuplicable();
|
|
|
|
|
AddInput("W", "The 2-D weight matrix of FC operator.").AsDuplicable();
|
|
|
|
|
AddInput("b", "The 1-D bias vector of FC operator");
|
|
|
|
|
|
|
|
|
|
AddOutput("Out", "The activated output matrix of FC operator");
|
|
|
|
|
AddOutput("Y", "The activated output matrix of FC operator");
|
|
|
|
|
AddOutput("mul_out", "The non-actived output of FC operator, X * W")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("add_out", "The non-actived output of FC operator, X * W + b")
|
|
|
|
@ -78,7 +97,7 @@ learned weights with a matrix multiplication followed by a bias offset
|
|
|
|
|
(optionally).
|
|
|
|
|
|
|
|
|
|
Equation:
|
|
|
|
|
Out = Act(sum_n{X_i * W_i} + b)
|
|
|
|
|
Y = Act(sum_n{X_i * W_i} + b)
|
|
|
|
|
|
|
|
|
|
where X_i is a 2D matrix of size (M x K), usually M is the minibatch size and
|
|
|
|
|
K is the number of features. W_i is also a 2D matrix of size (K x N),
|
|
|
|
|