commit b03a44e062

@@ -0,0 +1,30 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/space_to_depth_op.h"

namespace plat = paddle::platform;
namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    space_to_depth,
    ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);

REGISTER_OP_CUDA_KERNEL(
    space_to_depth_grad,
    ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
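
Editor's note: this translation unit only registers the GPU kernels for the forward and gradient ops; the actual computation lives in the shared header below, where one index-remapping functor serves both directions via its `forward` flag. As a minimal sketch (not part of the diff), the pairing can be pictured with a generic permutation in NumPy: the forward kernel scatters, the grad kernel gathers through the same index map, and that gather is the transpose of the permutation, hence the correct gradient.

import numpy as np

# Sketch only: `perm` stands in for the flat-index map computed by
# space_to_depth_compute; any fixed permutation illustrates the same point.
rng = np.random.RandomState(0)
n = 12
perm = rng.permutation(n)

x = rng.randn(n)
y = np.empty_like(x)
y[perm] = x            # forward kernel: out_[out_index] = x_[in_index]

dy = rng.randn(n)
dx = dy[perm]          # grad kernel:    out_[in_index] = x_[out_index]

# Adjoint identity <y, dy> == <x, dx> holds for every x and dy, confirming
# that gathering through the same map is the exact gradient of the scatter.
np.testing.assert_allclose(np.dot(y, dy), np.dot(x, dx))
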
@@ -0,0 +1,127 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

// Remaps every flat NCHW input index to its destination index in the
// rearranged output; forward == 1 scatters x into out, forward == 0 gathers
// through the same map (used by the gradient kernel).
template <typename T>
class space_to_depth_compute {
 public:
  HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h,
                                    int64_t c, int64_t batch,
                                    int64_t blocksize, int64_t forward, T *out)
      : x_(x),
        w_(w),
        h_(h),
        c_(c),
        batch_(batch),
        blocksize_(blocksize),
        forward_(forward),
        out_(out) {}

  HOSTDEVICE void operator()(int64_t in_index) {
    int64_t out_c = c_ / (blocksize_ * blocksize_);
    // calculate each dim position with index of tensor
    int64_t b = in_index / (c_ * h_ * w_);
    int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
    int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
    int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;

    int64_t c2 = k % out_c;
    int64_t offset = k / out_c;
    int64_t w2 = i * blocksize_ + offset % blocksize_;
    int64_t h2 = j * blocksize_ + offset / blocksize_;
    int64_t out_index =
        w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
    if (forward_)
      out_[out_index] = x_[in_index];
    else
      out_[in_index] = x_[out_index];
  }

 private:
  const T *x_;
  int64_t w_, h_, c_, batch_, blocksize_, forward_;
  T *out_;
};

template <typename DeviceContext, typename T>
class SpaceToDepthKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *out = context.Output<framework::LoDTensor>("Out");
    auto *x = context.Input<framework::LoDTensor>("X");
    auto blocksize = context.Attr<int64_t>("blocksize");
    auto in_dims = x->dims();
    out->mutable_data(context.GetPlace(), x->type());

    auto out_dims = out->dims();
    auto B = in_dims[0];
    auto C = in_dims[1];
    auto H = in_dims[2];
    auto W = in_dims[3];
    platform::ForRange<DeviceContext> for_range(
        context.template device_context<DeviceContext>(),
        static_cast<size_t>(x->numel()));

    auto *x_data = x->data<T>();
    auto *out_data = out->data<T>();
    paddle::operators::space_to_depth_compute<T> computer(
        x_data, W, H, C, B, blocksize, 1, out_data);
    for_range(computer);

    out->Resize(out_dims);
  }
};

template <typename DeviceContext, typename T>
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *d_out =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto *d_x =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    auto blocksize = context.Attr<int64_t>("blocksize");
    auto in_dims = d_x->dims();
    d_x->mutable_data(context.GetPlace(), d_out->type());

    auto B = in_dims[0];
    auto C = in_dims[1];
    auto H = in_dims[2];
    auto W = in_dims[3];

    platform::ForRange<DeviceContext> for_range(
        context.template device_context<DeviceContext>(),
        static_cast<size_t>(d_x->numel()));

    auto *dx_data = d_x->data<T>();
    auto *dout_data = d_out->data<T>();

    paddle::operators::space_to_depth_compute<T> computer(
        dout_data, W, H, C, B, blocksize, 0, dx_data);
    for_range(computer);

    d_x->Resize(in_dims);
  }
};

}  // namespace operators
}  // namespace paddle

#endif  // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
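
Editor's note: the mapping in space_to_depth_compute::operator() is easiest to follow on a tiny case. The sketch below (not part of the diff) re-implements the same flat-index arithmetic in Python for a (B, C, H, W) = (1, 4, 2, 2) input with blocksize = 2 and checks that the map is a bijection over all B*C*H*W flat indices, so the forward scatter and backward gather never collide.

# Sketch only: the same flat-index arithmetic as the HOSTDEVICE functor,
# written in Python for a tiny (B, C, H, W) = (1, 4, 2, 2), blocksize = 2 case.
def map_index(in_index, W, H, C, blocksize):
    out_c = C // (blocksize * blocksize)
    b = in_index // (C * H * W)
    k = (in_index % (C * H * W)) // (H * W)
    j = ((in_index % (C * H * W)) % (H * W)) // W
    i = ((in_index % (C * H * W)) % (H * W)) % W
    c2 = k % out_c
    offset = k // out_c
    w2 = i * blocksize + offset % blocksize
    h2 = j * blocksize + offset // blocksize
    return w2 + W * blocksize * (h2 + H * blocksize * (c2 + out_c * b))

B, C, H, W, s = 1, 4, 2, 2, 2
perm = [map_index(idx, W, H, C, s) for idx in range(B * C * H * W)]
assert sorted(perm) == list(range(B * C * H * W))  # bijection: no collisions
print(perm)  # [0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15]
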
@@ -0,0 +1,135 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest


class TestSpaceToDepthOp(OpTest):
    @staticmethod
    def helper(in_, width, height, channel, batch, blocksize, forward, out_):
        channel_out = channel // (blocksize * blocksize)
        for b in range(batch):
            for k in range(channel):
                for j in range(height):
                    for i in range(width):
                        in_index = i + width * (j + height * (k + channel * b))
                        channel2 = k % channel_out
                        offset = k // channel_out
                        width2 = i * blocksize + offset % blocksize
                        height2 = j * blocksize + offset // blocksize
                        out_index = width2 + width * blocksize * (
                            height2 + height * blocksize *
                            (channel2 + channel_out * b))
                        if forward:
                            out_[out_index] = in_[in_index]
                        else:
                            out_[in_index] = in_[out_index]

    def setUp(self):
        self.init_data()

        self.op_type = "space_to_depth"
        self.inputs = {"X": self.x}
        self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
                    self.x.shape[1], self.x.shape[0], self.blocksize,
                    self.forward, self.out_1d)
        self.out = np.reshape(self.out_1d, self.infered_shape)
        self.attrs = {"blocksize": self.blocksize}
        self.outputs = {"Out": self.out}

    def init_data(self):
        self.ori_shape = (32, 12, 6, 6)
        self.infered_shape = (32, 48, 3, 3)
        self.one_d_len = 32 * 48 * 3 * 3

        self.blocksize = 2
        self.x = np.random.random(self.ori_shape).astype('float32')
        self.x_1d = np.reshape(self.x, self.one_d_len)
        self.out = np.zeros(self.infered_shape).astype('float32')
        self.out_1d = np.reshape(self.out, self.one_d_len)
        self.forward = 1

    def test_check_output(self):
        place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.core.CPUPlace()
        self.check_output_with_place(place, 1e-5, None, False)

    def test_check_grad(self):
        place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.core.CPUPlace()
        self.check_grad_with_place(place, ['X'], 'Out')


class TestSpaceToDepthOpBasic(TestSpaceToDepthOp):
    def init_data(self):
        self.ori_shape = (32, 8, 6, 6)
        self.infered_shape = (32, 32, 3, 3)
        self.one_d_len = 32 * 32 * 3 * 3

        self.blocksize = 2
        self.x = np.random.random(self.ori_shape).astype('float32')
        self.x_1d = np.reshape(self.x, self.one_d_len)
        self.out = np.zeros(self.infered_shape).astype('float32')
        self.out_1d = np.reshape(self.out, self.one_d_len)
        self.forward = 1


class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp):
    def init_data(self):
        self.ori_shape = (32, 8, 6, 6)
        self.infered_shape = (32, 32, 3, 3)
        self.one_d_len = 32 * 32 * 3 * 3

        self.blocksize = 2
        self.x = np.random.random(self.ori_shape).astype('float64')
        self.x_1d = np.reshape(self.x, self.one_d_len)
        self.out = np.zeros(self.infered_shape).astype('float64')
        self.out_1d = np.reshape(self.out, self.one_d_len)
        self.forward = 1


class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp):
    def init_data(self):
        self.ori_shape = (32, 9, 6, 6)
        self.infered_shape = (32, 81, 2, 2)
        self.one_d_len = 32 * 81 * 2 * 2

        self.blocksize = 3
        self.x = np.random.random(self.ori_shape).astype('float32')
        self.x_1d = np.reshape(self.x, self.one_d_len)
        self.out = np.zeros(self.infered_shape).astype('float32')
        self.out_1d = np.reshape(self.out, self.one_d_len)
        self.forward = 1


class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp):
    def init_data(self):
        self.ori_shape = (32, 9, 9, 6)
        self.infered_shape = (32, 81, 3, 2)
        self.one_d_len = 32 * 81 * 3 * 2

        self.blocksize = 3
        self.x = np.random.random(self.ori_shape).astype('float32')
        self.x_1d = np.reshape(self.x, self.one_d_len)
        self.out = np.zeros(self.infered_shape).astype('float32')
        self.out_1d = np.reshape(self.out, self.one_d_len)
        self.forward = 1


if __name__ == '__main__':
    unittest.main()
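
Editor's note: for completeness, a hypothetical end-to-end usage sketch. It assumes the companion Python wrapper is exposed as fluid.layers.space_to_depth(x, blocksize); that wrapper is not part of this diff, so the call below is illustrative only, and every name other than the fluid executor/data APIs used in the test above is an assumption.

import numpy as np
import paddle.fluid as fluid

# Assumed API: fluid.layers.space_to_depth(x, blocksize) -- not in this diff.
data = fluid.layers.data(name='data', shape=[12, 6, 6], dtype='float32')
out = fluid.layers.space_to_depth(x=data, blocksize=2)  # per-sample [48, 3, 3]

place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() \
    else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

x = np.random.random((32, 12, 6, 6)).astype('float32')
result, = exe.run(feed={'data': x}, fetch_list=[out])
print(result.shape)  # expected (32, 48, 3, 3), matching the test's infered_shape
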