Add Pixel shuffle OP (#15782)
* add pixel_shuffle op * add pixel_shuffle op, test=develop * rewrite code, test=develop * delete useless comment, test=develop * Refine pixel_shuffle_op and unit testing * refine code,test=develop * refine .cu,test=develop * fix unittest,test=develop * Fix unit testing test=develop * resolve conflict, test=develop * fix test, test=develop * fix API, test=develop * fix test datatype bug,test=develop * polish comments,test=develop * add API,test=develop * test=develop * Add Pixel_Shuffle OP,test=develop * support python3,test=develop * add include memory to travis CI bug,test=developdevel
parent
38382f8e27
commit
229dc93277
@ -0,0 +1,135 @@
|
||||
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/pixel_shuffle_op.h"
|
||||
#include <memory>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class PixelShuffleOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of PixelShuffleOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of PixelShuffleOp should not be null.");
|
||||
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
||||
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
|
||||
|
||||
PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
|
||||
"Upscale_factor should devide the number of channel");
|
||||
|
||||
auto output_dims = input_dims;
|
||||
output_dims[0] = input_dims[0];
|
||||
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
|
||||
output_dims[2] = input_dims[2] * upscale_factor;
|
||||
output_dims[3] = input_dims[3] * upscale_factor;
|
||||
ctx->SetOutputDim("Out", output_dims);
|
||||
}
|
||||
};
|
||||
|
||||
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput(
|
||||
"X",
|
||||
"(Tensor, default Tensor<float>), "
|
||||
"the input feature data of PixelShuffleOp, the layout is [N C H W].");
|
||||
AddOutput(
|
||||
"Out",
|
||||
"(Tensor, default Tensor<float>), the output of "
|
||||
"PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
|
||||
AddAttr<int>("upscale_factor",
|
||||
"the factor to increase spatial resolution by.")
|
||||
.SetDefault(1)
|
||||
.AddCustomChecker([](const int& upscale_factor) {
|
||||
PADDLE_ENFORCE_GE(upscale_factor, 1,
|
||||
"upscale_factor should be larger than 0.");
|
||||
});
|
||||
|
||||
AddComment(R"DOC(
|
||||
Pixel Shuffle operator
|
||||
This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
|
||||
to a tensor of shape :math:`(C, H \times r, W \times r)`.
|
||||
|
||||
This is useful for implementing efficient sub-pixel convolution
|
||||
with a stride of :math:`1/r`.
|
||||
|
||||
Please refer to the paper:
|
||||
`Real-Time Single Image and Video Super-Resolution Using an Efficient
|
||||
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
|
||||
by Shi et. al (2016) for more details.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker {
|
||||
public:
|
||||
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
||||
|
||||
std::unique_ptr<framework::OpDesc> Apply() const override {
|
||||
auto* op = new framework::OpDesc();
|
||||
op->SetType("pixel_shuffle_grad");
|
||||
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
||||
op->SetAttrMap(Attrs());
|
||||
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
||||
return std::unique_ptr<framework::OpDesc>(op);
|
||||
}
|
||||
};
|
||||
|
||||
class PixelShuffleGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||
"Input(Out@Grad) should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
||||
"Output(X@Grad) should not be null");
|
||||
|
||||
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
||||
PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");
|
||||
|
||||
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
|
||||
|
||||
auto dx_dims = do_dims;
|
||||
dx_dims[0] = do_dims[0];
|
||||
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
|
||||
dx_dims[2] = do_dims[2] / upscale_factor;
|
||||
dx_dims[3] = do_dims[3] / upscale_factor;
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
|
||||
ops::PixelShuffleGradMaker);
|
||||
|
||||
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
pixel_shuffle,
|
||||
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
pixel_shuffle_grad,
|
||||
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,26 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/pixel_shuffle_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
namespace plat = paddle::platform;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
pixel_shuffle, ops::PixelShuffleOpKernel<plat::CUDADeviceContext, float>,
|
||||
ops::PixelShuffleOpKernel<plat::CUDADeviceContext, double>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
pixel_shuffle_grad,
|
||||
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
|
||||
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);
|
@ -0,0 +1,82 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class PixelShuffleOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* in = ctx.Input<framework::Tensor>("X");
|
||||
auto* out = ctx.Output<framework::Tensor>("Out");
|
||||
out->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int factor = ctx.Attr<int>("upscale_factor");
|
||||
|
||||
auto in_dims = in->dims();
|
||||
auto o_dims = out->dims();
|
||||
|
||||
framework::Tensor t;
|
||||
t.ShareDataWith(*in);
|
||||
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
|
||||
|
||||
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
|
||||
|
||||
framework::Tensor o;
|
||||
o.ShareDataWith(*out);
|
||||
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
|
||||
|
||||
math::Transpose<DeviceContext, T, 6> trans;
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
trans(dev_ctx, t, &o, axis);
|
||||
out->Resize(o_dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
dx->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
int factor = ctx.Attr<int>("upscale_factor");
|
||||
|
||||
auto do_dims = dout->dims();
|
||||
auto dx_dims = dx->dims();
|
||||
|
||||
framework::Tensor t;
|
||||
t.ShareDataWith(*dout);
|
||||
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
|
||||
|
||||
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
|
||||
|
||||
framework::Tensor o;
|
||||
o.ShareDataWith(*dx);
|
||||
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
|
||||
|
||||
math::Transpose<DeviceContext, T, 6> trans;
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
trans(dev_ctx, t, &o, axis);
|
||||
dx->Resize(dx_dims);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestPixelShuffle(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "pixel_shuffle"
|
||||
n, c, h, w = 2, 9, 4, 4
|
||||
up_factor = 3
|
||||
shape = [n, c, h, w]
|
||||
x = np.random.random(shape).astype("float32")
|
||||
new_shape = (n, c // (up_factor * up_factor), up_factor, up_factor, h,
|
||||
w)
|
||||
# reshape to (num,output_channel,upscale_factor,upscale_factor,h,w)
|
||||
npresult = np.reshape(x, new_shape)
|
||||
# transpose to (num,output_channel,h,upscale_factor,w,upscale_factor)
|
||||
npresult = npresult.transpose(0, 1, 4, 2, 5, 3)
|
||||
oshape = [n, c // (up_factor * up_factor), h * up_factor, w * up_factor]
|
||||
npresult = np.reshape(npresult, oshape)
|
||||
|
||||
self.inputs = {'X': x}
|
||||
self.outputs = {'Out': npresult}
|
||||
self.attrs = {'upscale_factor': up_factor}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], 'Out')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue