|
|
|
@ -23,23 +23,35 @@ class FSPOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of FSPOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp_op");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp_op");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fsp_op");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
x_dims.size() == 4,
|
|
|
|
|
"The Input(X) must have shape [batch_size, channel, height, width].");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
y_dims.size() == 4,
|
|
|
|
|
"The Input(Y) must have shape [batch_size, channel, height, width].");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
(x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]),
|
|
|
|
|
"The Input(X) and Input(Y) should have the same height and width.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size(), 4UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(X) must have shape [batch_size, channel, height, width]."
|
|
|
|
|
"Now the dimension of 'X' is %d.",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
y_dims.size(), 4UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(Y) must have shape [batch_size, channel, height, width]."
|
|
|
|
|
"Now the dimension of 'Y' is %d.",
|
|
|
|
|
y_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[2], y_dims[2],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(X)(%d) and Input(Y)(%d) should have the same height.",
|
|
|
|
|
x_dims[2], y_dims[2]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[3], y_dims[3],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(X)(%d) and Input(Y)(%d) should have the same width.",
|
|
|
|
|
x_dims[3], y_dims[3]));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]});
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|