|
|
|
@ -31,51 +31,76 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of SpaceToDepthOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) of SpaceToDepthOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SpaceToDepthOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(Out) of SpaceToDepthOp should not be null."));
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
|
|
|
|
|
"input should be a 4D tensor"));
|
|
|
|
|
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
|
|
|
|
|
PADDLE_ENFORCE_GT(blocksize, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The blocksize should be Greater than 1"));
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
|
|
|
|
|
"input channel should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[1], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input channel should be Greater than 0"));
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[2], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Height should be Greater than 0"));
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[3], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Width should be Greater than 0"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[1] % (blocksize * blocksize), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input channel should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
|
|
|
|
|
"input Height should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Height should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
|
|
|
|
|
"input Width should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Width should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
} else {
|
|
|
|
|
if (x_dims[1] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[1], 0,
|
|
|
|
|
"input channel should be Greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
|
|
|
|
|
"input channel should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input channel should be Greater than 0"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[1] % (blocksize * blocksize), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input channel should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
}
|
|
|
|
|
if (x_dims[2] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[2], 0,
|
|
|
|
|
"input Height should be Greater than 0");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
|
|
|
|
|
"input Height should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Height should be Greater than 0"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[2] % (blocksize), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Height should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (x_dims[3] != -1) {
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
|
|
|
|
|
"input Width should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize");
|
|
|
|
|
PADDLE_ENFORCE_GT(x_dims[3], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Width should be Greater than 0"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims[3] % (blocksize), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"input Width should be divisible of the square of "
|
|
|
|
|
"SpaceToDepthOp blocksize"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -156,9 +181,11 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) shouldn't be null."));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null."));
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|