|
|
|
@ -20,17 +20,27 @@ class PixelShuffleOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of PixelShuffleOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of PixelShuffleOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(X) of PixelShuffleOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Output(Out) of PixelShuffleOp should not be null."));
|
|
|
|
|
|
|
|
|
|
auto input_dims = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
input_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
|
|
|
|
|
input_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
|
|
|
|
|
"Upscale_factor should devide the number of channel");
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_dims[1] % (upscale_factor * upscale_factor), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The square of upscale_factor[%u] should divide the "
|
|
|
|
|
"number of channel[%u]",
|
|
|
|
|
input_dims[1], upscale_factor * upscale_factor));
|
|
|
|
|
|
|
|
|
|
auto output_dims = input_dims;
|
|
|
|
|
output_dims[0] = input_dims[0];
|
|
|
|
@ -57,7 +67,8 @@ class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.SetDefault(1)
|
|
|
|
|
.AddCustomChecker([](const int& upscale_factor) {
|
|
|
|
|
PADDLE_ENFORCE_GE(upscale_factor, 1,
|
|
|
|
|
"upscale_factor should be larger than 0.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"upscale_factor should be larger than 0."));
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
@ -95,13 +106,19 @@ class PixelShuffleGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@Grad) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Output(X@Grad) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::NotFound("Input(Out@Grad) should not be null"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput(framework::GradVarName("X")), true,
|
|
|
|
|
platform::errors::NotFound("Output(X@Grad) should not be null"));
|
|
|
|
|
|
|
|
|
|
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
do_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input should be a 4-D tensor of format [N, C, H, W], but got %u.",
|
|
|
|
|
do_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
|
|
|
|
|
|
|
|
|
|