|
|
|
@ -22,37 +22,59 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
void FusionRepeatedFCReluOp::InferShape(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of FusionRepeatedFCReluOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusionRepeatedFCRelu");
|
|
|
|
|
auto sz = ctx->Inputs("W").size();
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
sz, 1UL, "Inputs(W) of FusionRepeatedFCReluOp should larger than 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->Inputs("Bias").size(), sz,
|
|
|
|
|
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
|
|
|
|
|
"equal to inputs size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->Outputs("ReluOut").size(), sz - 1,
|
|
|
|
|
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
|
|
|
|
|
"be equal to inputs size -1.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FusionRepeatedFCReluOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_GT(sz, 1UL, platform::errors::InvalidArgument(
|
|
|
|
|
"Inputs(W) of FusionRepeatedFCReluOp should "
|
|
|
|
|
"be greater than 1, but received value is %d.",
|
|
|
|
|
sz));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->Inputs("Bias").size(), sz,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Size of inputs(Bias) of FusionRepeatedFCReluOp should be "
|
|
|
|
|
"equal to inputs size %d, but received value is %d.",
|
|
|
|
|
sz, ctx->Inputs("Bias").size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->Outputs("ReluOut").size(), sz - 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Size of output(ReluOut) of FusionRepeatedFCReluOp should "
|
|
|
|
|
"be equal to inputs size minus one %d, but received value is %d",
|
|
|
|
|
sz - 1, ctx->Outputs("ReluOut").size()));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
|
|
|
|
|
"FusionRepeatedFCRelu");
|
|
|
|
|
|
|
|
|
|
auto i_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(i_dims.size(), 2, "Input shape size should be 2");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
i_dims.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input shape size should be 2, but received value is %d.",
|
|
|
|
|
i_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto w_dims = ctx->GetInputsDim("W");
|
|
|
|
|
auto b_dims = ctx->GetInputsDim("Bias");
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), b_dims.size(),
|
|
|
|
|
"Shape size of weight and bias should be equal");
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims.size(), sz,
|
|
|
|
|
"Shape size of weight and bias should be equal");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape size of weight and bias should be equal, but "
|
|
|
|
|
"weight size is %d, bias size is %d.",
|
|
|
|
|
w_dims.size(), b_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(i_dims[1], w_dims[0][0],
|
|
|
|
|
"inpute width should be equal with weight height");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input width should be equal to weight height, but "
|
|
|
|
|
"input width is %d, weight height is %d.",
|
|
|
|
|
i_dims[1], w_dims[0][0]));
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < sz; ++i) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(w_dims[i].size(), 2,
|
|
|
|
|
"Every weight shape size should be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(b_dims[i]), w_dims[i][1],
|
|
|
|
|
"The length of Bias must be equal with w_dims[1].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Every weight shape size should be 2., but received "
|
|
|
|
|
"w_dims[%d].size() = %d.",
|
|
|
|
|
i, w_dims[i].size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::product(b_dims[i]), w_dims[i][1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The length of Bias must be equal with w_dims[1], but received "
|
|
|
|
|
"product(b_dims[%d]) = %d, w_dims[%d][1] = %d.",
|
|
|
|
|
i, framework::product(b_dims[i]), i, w_dims[i][1]));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", {i_dims[0], w_dims[sz - 1][1]});
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|