|
|
|
@ -24,6 +24,15 @@ class FCOp : public NetOp {
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
PADDLE_ENFORCE(!Inputs("X").empty(),
|
|
|
|
|
"Inputs(X) of FCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(!Inputs("W").empty(),
|
|
|
|
|
"Inputs(W) of FCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(!Outputs("MulOut").empty(),
|
|
|
|
|
"Outputs(MulOut) of FCOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
|
|
|
|
|
"Output(Out) of FCOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto x = Inputs("X");
|
|
|
|
|
auto w = Inputs("W");
|
|
|
|
|
auto mul_out = Outputs("MulOut");
|
|
|
|
@ -68,6 +77,10 @@ class FCOp : public NetOp {
|
|
|
|
|
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
|
|
|
|
|
auto sum_out = mul_out[0];
|
|
|
|
|
if (n > 1) {
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(SumOut) of FCOp should not be null when the "
|
|
|
|
|
"size of Inputs(X) > 1.");
|
|
|
|
|
|
|
|
|
|
sum_out = Output("SumOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
|
|
|
|
|
{{"Out", {sum_out}}}, {}));
|
|
|
|
@ -81,6 +94,10 @@ class FCOp : public NetOp {
|
|
|
|
|
auto b = Input("B");
|
|
|
|
|
auto add_out = sum_out;
|
|
|
|
|
if (b != framework::kEmptyVarName) {
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
Output("AddOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(AddOut) of FCOp should not be null when Input(B) is set.");
|
|
|
|
|
|
|
|
|
|
add_out = Output("AddOut");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
|
|
|
|
@ -176,11 +193,5 @@ Activation type can be set to `identity` (default), `sigmoid` or `softmax`.
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
USE_OP(mul);
|
|
|
|
|
USE_OP(rowwise_add);
|
|
|
|
|
USE_NO_KERNEL_OP(identity);
|
|
|
|
|
USE_OP(sigmoid);
|
|
|
|
|
USE_OP(softmax);
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker);
|
|
|
|
|