From 70351de1b5c29162247dd9f6f0da1f30a617d51b Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 22 Oct 2018 12:22:23 +0000 Subject: [PATCH 01/14] test=develop --- paddle/fluid/operators/reorg_op.cc | 127 ++++++++++++++++++ paddle/fluid/operators/reorg_op.cu | 29 ++++ paddle/fluid/operators/reorg_op.h | 126 +++++++++++++++++ python/paddle/fluid/layers/nn.py | 52 +++++++ python/paddle/fluid/op.py | 2 + .../fluid/tests/unittests/test_layers.py | 11 ++ .../fluid/tests/unittests/test_reorg_op.py | 93 +++++++++++++ 7 files changed, 440 insertions(+) create mode 100644 paddle/fluid/operators/reorg_op.cc create mode 100644 paddle/fluid/operators/reorg_op.cu create mode 100644 paddle/fluid/operators/reorg_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_reorg_op.py diff --git a/paddle/fluid/operators/reorg_op.cc b/paddle/fluid/operators/reorg_op.cc new file mode 100644 index 0000000000..1f9da1f797 --- /dev/null +++ b/paddle/fluid/operators/reorg_op.cc @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/reorg_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class ReorgOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of reorgOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of reorgOp should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); + auto stride = ctx->Attrs().Get("stride"); + + PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); + PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); + + PADDLE_ENFORCE_EQ( + x_dims[1] % (stride * stride), 0, + "input channel should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ( + x_dims[2] % (stride), 0, + "input Height should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ( + x_dims[3] % (stride), 0, + "input Width should be dvisible of the square of reorg stride"); + + VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride" + << stride << std::endl; + + std::vector output_shape(4, 0); // [B,C,H,W] + output_shape[0] = x_dims[0]; + output_shape[1] = x_dims[1] * stride * stride; + output_shape[2] = x_dims[2] / stride; + output_shape[3] = x_dims[3] / stride; + + auto out_dims = framework::make_ddim(output_shape); + + ctx->SetOutputDim("Out", out_dims); + + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } +}; + +class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor). The input should be a 4D tensor B * C * W * H of reorg " + "operator."); + AddOutput("Out", + "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " + "reorg operator."); + AddAttr("stride", + "(int64_t, default 1) stride used to do reorgnization.") + .SetDefault(1) + .EqualGreaterThan(1); + AddComment(R"DOC( + reorg operator used in Yolo v2. + The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, + + Reshape Input(X) into the shape according to Attr(stride). The + data in Input(X) are unchanged. + + Examples: + + 1. Given a 3-D tensor Input(X) with a shape [2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) + into a 3-D tensor with shape [2048, 13, 13] and leaving Input(X)'s data unchanged. + + )DOC"); + } +}; + +class ReorgGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp); +REGISTER_OP_CPU_KERNEL( + reorg, ops::ReorgKernel, + ops::ReorgKernel, + ops::ReorgKernel); +REGISTER_OP_CPU_KERNEL( + reorg_grad, ops::ReorgGradKernel, + ops::ReorgGradKernel, + ops::ReorgGradKernel); diff --git a/paddle/fluid/operators/reorg_op.cu b/paddle/fluid/operators/reorg_op.cu new file mode 100644 index 0000000000..de1c7d7468 --- /dev/null +++ b/paddle/fluid/operators/reorg_op.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/reorg_op.h" + +namespace plat = paddle::platform; +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + reorg, ops::ReorgKernel, + ops::ReorgKernel, + ops::ReorgKernel); + +REGISTER_OP_CUDA_KERNEL( + reorg_grad, + ops::ReorgGradKernel, + ops::ReorgGradKernel, + ops::ReorgGradKernel); diff --git a/paddle/fluid/operators/reorg_op.h b/paddle/fluid/operators/reorg_op.h new file mode 100644 index 0000000000..108437b4d8 --- /dev/null +++ b/paddle/fluid/operators/reorg_op.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#define PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +template +class reorg_cpu { + public: + HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c, + int64_t batch, int64_t stride, int64_t forward, T *out) + : x_(x), + w_(w), + h_(h), + c_(c), + batch_(batch), + stride_(stride), + forward_(forward), + out_(out) {} + + HOSTDEVICE void operator()(int64_t in_index) { + int64_t out_c = c_ / (stride_ * stride_); + // calculate each dim position with index of tensor + int64_t b = in_index / (c_ * h_ * w_); + int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_); + int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_; + int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_; + + int64_t c2 = k % out_c; + int64_t offset = k / out_c; + int64_t w2 = i * stride_ + offset % stride_; + int64_t h2 = j * stride_ + offset / stride_; + int64_t out_index = + w2 + w_ * stride_ * (h2 + h_ * stride_ * (c2 + out_c * b)); + if (forward_) + out_[out_index] = x_[in_index]; + else + out_[in_index] = x_[out_index]; + } + + private: + const T *x_; + int64_t w_, h_, c_, batch_, stride_, forward_; + T *out_; +}; + +template +class ReorgKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *out = context.Output("Out"); + auto *x = context.Input("X"); + auto stride = context.Attr("stride"); + auto in_dims = x->dims(); + out->mutable_data(context.GetPlace(), x->type()); + + auto out_dims = out->dims(); + auto B = in_dims[0]; + auto C = in_dims[1]; + auto H = in_dims[2]; + auto W = in_dims[3]; + platform::ForRange for_range( + context.template device_context(), + static_cast(x->numel())); + + auto *x_data = x->data(); + auto *out_data = out->data(); + paddle::operators::reorg_cpu reorg(x_data, W, H, C, B, stride, 1, + out_data); + for_range(reorg); + + out->Resize(out_dims); + } +}; + +template +class ReorgGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *d_out = + context.Input(framework::GradVarName("Out")); + auto *d_x = + context.Output(framework::GradVarName("X")); + auto stride = context.Attr("stride"); + auto in_dims = d_x->dims(); + d_x->mutable_data(context.GetPlace(), d_out->type()); + + auto B = in_dims[0]; + auto C = in_dims[1]; + auto H = in_dims[2]; + auto W = in_dims[3]; + + platform::ForRange for_range( + context.template device_context(), + static_cast(d_x->numel())); + + auto *dx_data = d_x->data(); + auto *dout_data = d_out->data(); + + paddle::operators::reorg_cpu reorg(dout_data, W, H, C, B, stride, 0, + dx_data); + for_range(reorg); + + d_x->Resize(in_dims); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8c0ef7a824..35a1a899e7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -150,6 +150,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', + 'reorg', ] @@ -7084,3 +7085,54 @@ def maxout(x, groups, name=None): attrs={"groups": groups}, outputs={"Out": out}) return out + + +def reorg(x, stride, name=None): + """ + Gives a stride to reorg the input tensor + + Here are some example: + + input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2 + + reorg will do some math work to reorder the elements of input according to stride to construt + put with shape [batch, channel * stride * stride, height/stride, width/stride] + + reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape + + Args: + x(variable): The input tensor. + stride(variable): The stride to reorg + + Returns: + Variable: The output tensor. + + Raises: + TypeError: stride type must be a long. + + Examples: + .. code-block:: python + + data = fluid.layers.data( + name='data', shape=[1, 4, 2, 2], dtype='float32') + reorged = fluid.layers.reorged( + x=data, stride=2) + """ + + if not (isinstance(stride, long)): + raise ValueError("stride must be a python long") + + helper = LayerHelper("reorg", **locals()) + if name is None: + out = helper.create_tmp_variable(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="reorg", + inputs={"X": x}, + attrs={"stride": stride}, + outputs={"Out": out}) + + return out diff --git a/python/paddle/fluid/op.py b/python/paddle/fluid/op.py index 667db10d3e..52b169fb3c 100644 --- a/python/paddle/fluid/op.py +++ b/python/paddle/fluid/op.py @@ -108,6 +108,8 @@ class OpDescCreationMethod(object): new_attr.i = user_defined_attr elif attr.type == framework_pb2.FLOAT: new_attr.f = user_defined_attr + elif attr.type == framework_pb2.LONG: + new_attr.l = user_defined_attr elif attr.type == framework_pb2.STRING: new_attr.s = user_defined_attr elif attr.type == framework_pb2.BOOLEAN: diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 1d8d0b55f0..f34c385617 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -240,6 +240,17 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.softmax(hid)) print(str(program)) + def test_reorg(self): + program = Program() + with program_guard(program): + data = layers.data( + name="data", + shape=[32, 9, 6, 6], + append_batch_size=False, + dtype='float32') + self.assertIsNotNone(layers.reorg(data, long(3))) + print(str(program)) + def test_sequence_unsqueeze(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_reorg_op.py new file mode 100644 index 0000000000..9d4fa4d0ff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_reorg_op.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle.fluid as fluid +from op_test import OpTest + + +class TestReorgOp(OpTest): + @staticmethod + def helper(in_, width, height, channel, batch, stride, forward, out_): + channel_out = channel / (stride * stride) + for b in range(batch): + for k in range(channel): + for j in range(height): + for i in range(width): + in_index = i + width * (j + height * (k + channel * b)) + channel2 = k % channel_out + offset = k / channel_out + width2 = i * stride + offset % stride + height2 = j * stride + offset / stride + out_index = width2 + width * stride * ( + height2 + height * stride * + (channel2 + channel_out * b)) + if forward: + out_[out_index] = in_[in_index] + else: + out_[in_index] = in_[out_index] + + def setUp(self): + self.init_data() + + self.op_type = "reorg" + self.inputs = {"X": self.x} + self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], + self.x.shape[1], self.x.shape[0], self.stride, self.forward, + self.out_1d) + self.out = np.reshape(self.out_1d, self.infered_shape) + self.attrs = {"stride": long(self.stride)} + self.outputs = {"Out": self.out} + + def init_data(self): + self.ori_shape = (32, 12, 6, 6) + self.infered_shape = (32, 48, 3, 3) + self.one_d_len = 32 * 48 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + def test_check_output(self): + place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.core.CPUPlace() + self.check_output_with_place(place, 1e-5, None, False) + + def test_check_grad(self): + place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.core.CPUPlace() + self.check_grad_with_place(place, ['X'], 'Out') + + +class TestReorgOp2(TestReorgOp): + def init_data(self): + self.ori_shape = (32, 9, 6, 6) + self.infered_shape = (32, 81, 2, 2) + self.one_d_len = 32 * 81 * 2 * 2 + + self.stride = 3 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +if __name__ == '__main__': + unittest.main() From ff07dc315ec5351c84754de8b4e8f944e44628db Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 06:43:46 +0000 Subject: [PATCH 02/14] test=develop --- paddle/fluid/operators/reorg_op.cc | 4 ++-- python/paddle/fluid/layers/nn.py | 4 ++-- python/paddle/fluid/tests/unittests/test_layers.py | 2 +- python/paddle/fluid/tests/unittests/test_reorg_op.py | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/reorg_op.cc b/paddle/fluid/operators/reorg_op.cc index 1f9da1f797..757761ab51 100644 --- a/paddle/fluid/operators/reorg_op.cc +++ b/paddle/fluid/operators/reorg_op.cc @@ -91,8 +91,8 @@ class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { Examples: - 1. Given a 3-D tensor Input(X) with a shape [2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) - into a 3-D tensor with shape [2048, 13, 13] and leaving Input(X)'s data unchanged. + 1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) + into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged. )DOC"); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f04b268626..d112793c71 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7470,8 +7470,8 @@ def reorg(x, stride, name=None): x=data, stride=2) """ - if not (isinstance(stride, long)): - raise ValueError("stride must be a python long") + if not (isinstance(stride, int)): + raise ValueError("stride must be a python Int") helper = LayerHelper("reorg", **locals()) if name is None: diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index cc354c9005..e59f56b455 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -256,7 +256,7 @@ class TestBook(unittest.TestCase): shape=[32, 9, 6, 6], append_batch_size=False, dtype='float32') - self.assertIsNotNone(layers.reorg(data, long(3))) + self.assertIsNotNone(layers.reorg(data, 3)) print(str(program)) def test_sequence_unsqueeze(self): diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_reorg_op.py index 9d4fa4d0ff..b773606fe3 100644 --- a/python/paddle/fluid/tests/unittests/test_reorg_op.py +++ b/python/paddle/fluid/tests/unittests/test_reorg_op.py @@ -22,16 +22,16 @@ from op_test import OpTest class TestReorgOp(OpTest): @staticmethod def helper(in_, width, height, channel, batch, stride, forward, out_): - channel_out = channel / (stride * stride) + channel_out = channel // (stride * stride) for b in range(batch): for k in range(channel): for j in range(height): for i in range(width): in_index = i + width * (j + height * (k + channel * b)) channel2 = k % channel_out - offset = k / channel_out + offset = k // channel_out width2 = i * stride + offset % stride - height2 = j * stride + offset / stride + height2 = j * stride + offset // stride out_index = width2 + width * stride * ( height2 + height * stride * (channel2 + channel_out * b)) From 70b52733630d4ef34b12fdf9dce65ca3cf0d4415 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 10:55:43 +0000 Subject: [PATCH 03/14] test=develop --- python/paddle/fluid/tests/unittests/test_reorg_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_reorg_op.py index b773606fe3..a3afabe7af 100644 --- a/python/paddle/fluid/tests/unittests/test_reorg_op.py +++ b/python/paddle/fluid/tests/unittests/test_reorg_op.py @@ -49,7 +49,7 @@ class TestReorgOp(OpTest): self.x.shape[1], self.x.shape[0], self.stride, self.forward, self.out_1d) self.out = np.reshape(self.out_1d, self.infered_shape) - self.attrs = {"stride": long(self.stride)} + self.attrs = {"stride": self.stride} self.outputs = {"Out": self.out} def init_data(self): From eab2e5e5d4160aacca7f499920e6570a8ddfeb32 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 10:56:41 +0000 Subject: [PATCH 04/14] test=develop --- python/paddle/fluid/tests/unittests/test_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index e59f56b455..92c60da715 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -252,7 +252,7 @@ class TestBook(unittest.TestCase): program = Program() with program_guard(program): data = layers.data( - name="data", + name='data', shape=[32, 9, 6, 6], append_batch_size=False, dtype='float32') From 6259dba5bdd1cb92decbf6c6ba8a0f6090899545 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 11:01:14 +0000 Subject: [PATCH 05/14] test=develop --- python/paddle/fluid/layers/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d112793c71..16ea7a7ddf 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7470,10 +7470,11 @@ def reorg(x, stride, name=None): x=data, stride=2) """ + helper = LayerHelper("reorg", **locals()) + if not (isinstance(stride, int)): raise ValueError("stride must be a python Int") - helper = LayerHelper("reorg", **locals()) if name is None: out = helper.create_tmp_variable(dtype=x.dtype) else: From 782ef3c5dc8b2b0140d7467a39b98c1c7074cc67 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 12:05:39 +0000 Subject: [PATCH 06/14] test=develop --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index f71bd894c0..97068f1979 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7494,7 +7494,7 @@ def reorg(x, stride, name=None): raise ValueError("stride must be a python Int") if name is None: - out = helper.create_tmp_variable(dtype=x.dtype) + out = helper.create_variable_for_type_inference(dtype=x.dtype) else: out = helper.create_variable( name=name, dtype=x.dtype, persistable=False) From ab808c36dad151653566c539e814d7309b9ddf96 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Tue, 23 Oct 2018 12:37:12 +0000 Subject: [PATCH 07/14] test=develop --- python/paddle/fluid/layers/nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 97068f1979..e7f343508a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7494,7 +7494,8 @@ def reorg(x, stride, name=None): raise ValueError("stride must be a python Int") if name is None: - out = helper.create_variable_for_type_inference(dtype=x.dtype) + out = helper.create_variable_for_type_inference( + dtype=x.dtype) #fix create else: out = helper.create_variable( name=name, dtype=x.dtype, persistable=False) From c056328563e87ab9d2b14d50d070f4c6d139afe0 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Wed, 24 Oct 2018 05:20:30 +0000 Subject: [PATCH 08/14] test=develop --- paddle/fluid/API.spec | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 19ef23cdfa..5c4aa6158e 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.reorg ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) From 9cad409f2a9a76d918431ec85754cff7dcf5bcb4 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 25 Oct 2018 09:03:31 +0000 Subject: [PATCH 09/14] test=develop --- paddle/fluid/API.spec | 2 +- .../{reorg_op.cc => space_to_depth_op.cc} | 68 ++++++++++--------- .../{reorg_op.cu => space_to_depth_op.cu} | 17 ++--- .../{reorg_op.h => space_to_depth_op.h} | 29 ++++---- python/paddle/fluid/layers/nn.py | 37 +++++----- .../fluid/tests/unittests/test_layers.py | 4 +- ..._reorg_op.py => test_space_to_depth_op.py} | 48 ++++++++++++- 7 files changed, 127 insertions(+), 78 deletions(-) rename paddle/fluid/operators/{reorg_op.cc => space_to_depth_op.cc} (62%) rename paddle/fluid/operators/{reorg_op.cu => space_to_depth_op.cu} (57%) rename paddle/fluid/operators/{reorg_op.h => space_to_depth_op.h} (79%) rename python/paddle/fluid/tests/unittests/{test_reorg_op.py => test_space_to_depth_op.py} (67%) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 5c4aa6158e..3ac9fe31b4 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -174,7 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.reorg ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/reorg_op.cc b/paddle/fluid/operators/space_to_depth_op.cc similarity index 62% rename from paddle/fluid/operators/reorg_op.cc rename to paddle/fluid/operators/space_to_depth_op.cc index 757761ab51..a9a266a3f7 100644 --- a/paddle/fluid/operators/reorg_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -12,44 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/reorg_op.h" +#include "paddle/fluid/operators/space_to_depth_op.h" #include #include namespace paddle { namespace operators { -class ReorgOp : public framework::OperatorWithKernel { +class SpaceToDepthOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of reorgOp should not be null."); + "Input(X) of SpaceToDepthOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of reorgOp should not be null."); + "Output(Out) of SpaceToDepthOp should not be null."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); auto stride = ctx->Attrs().Get("stride"); - PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0"); + PADDLE_ENFORCE_GT(stride, 1, "The stride should be Greater than 1"); PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); - PADDLE_ENFORCE_EQ( - x_dims[1] % (stride * stride), 0, - "input channel should be dvisible of the square of reorg stride"); - PADDLE_ENFORCE_EQ( - x_dims[2] % (stride), 0, - "input Height should be dvisible of the square of reorg stride"); - PADDLE_ENFORCE_EQ( - x_dims[3] % (stride), 0, - "input Width should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ(x_dims[1] % (stride * stride), 0, + "input channel should be divisible of the square of " + "SpaceToDepthOp stride"); + PADDLE_ENFORCE_EQ(x_dims[2] % (stride), 0, + "input Height should be divisible of the square of " + "SpaceToDepthOp stride"); + PADDLE_ENFORCE_EQ(x_dims[3] % (stride), 0, + "input Width should be divisible of the square of " + "SpaceToDepthOp stride"); - VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride" - << stride << std::endl; + VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims + << "Attribute stride" << stride << std::endl; std::vector output_shape(4, 0); // [B,C,H,W] output_shape[0] = x_dims[0]; @@ -69,19 +69,21 @@ class ReorgOp : public framework::OperatorWithKernel { } }; -class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { +class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor). The input should be a 4D tensor B * C * W * H of reorg " + "(Tensor). The input should be a 4D tensor B * C * W * H of " + "SpaceToDepthOp " "operator."); AddOutput("Out", "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " - "reorg operator."); - AddAttr("stride", - "(int64_t, default 1) stride used to do reorgnization.") - .SetDefault(1) - .EqualGreaterThan(1); + "SpaceToDepthOp operator."); + AddAttr( + "stride", + "(int64_t, default 2) stride used to do change Space To Depth.") + .SetDefault(2) + .GreaterThan(1); AddComment(R"DOC( reorg operator used in Yolo v2. The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, @@ -98,7 +100,7 @@ class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class ReorgGradOp : public framework::OperatorWithKernel { +class SpaceToDepthGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -114,14 +116,16 @@ class ReorgGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker, +REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp); +REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp); REGISTER_OP_CPU_KERNEL( - reorg, ops::ReorgKernel, - ops::ReorgKernel, - ops::ReorgKernel); + space_to_depth, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel); REGISTER_OP_CPU_KERNEL( - reorg_grad, ops::ReorgGradKernel, - ops::ReorgGradKernel, - ops::ReorgGradKernel); + space_to_depth_grad, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel); diff --git a/paddle/fluid/operators/reorg_op.cu b/paddle/fluid/operators/space_to_depth_op.cu similarity index 57% rename from paddle/fluid/operators/reorg_op.cu rename to paddle/fluid/operators/space_to_depth_op.cu index de1c7d7468..38d0a66273 100644 --- a/paddle/fluid/operators/reorg_op.cu +++ b/paddle/fluid/operators/space_to_depth_op.cu @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reorg_op.h" +#include "paddle/fluid/operators/space_to_depth_op.h" namespace plat = paddle::platform; namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - reorg, ops::ReorgKernel, - ops::ReorgKernel, - ops::ReorgKernel); + space_to_depth, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel); REGISTER_OP_CUDA_KERNEL( - reorg_grad, - ops::ReorgGradKernel, - ops::ReorgGradKernel, - ops::ReorgGradKernel); + space_to_depth_grad, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel); diff --git a/paddle/fluid/operators/reorg_op.h b/paddle/fluid/operators/space_to_depth_op.h similarity index 79% rename from paddle/fluid/operators/reorg_op.h rename to paddle/fluid/operators/space_to_depth_op.h index 108437b4d8..a236c1d5b7 100644 --- a/paddle/fluid/operators/reorg_op.h +++ b/paddle/fluid/operators/space_to_depth_op.h @@ -11,9 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_ -#define PADDLE_FLUID_OPERATORS_REORG_OP_H_ -#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ +#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ +#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" @@ -22,10 +22,11 @@ namespace paddle { namespace operators { template -class reorg_cpu { +class space_to_depth_compute { public: - HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c, - int64_t batch, int64_t stride, int64_t forward, T *out) + HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c, + int64_t batch, int64_t stride, + int64_t forward, T *out) : x_(x), w_(w), h_(h), @@ -62,7 +63,7 @@ class reorg_cpu { }; template -class ReorgKernel : public framework::OpKernel { +class SpaceToDepthKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *out = context.Output("Out"); @@ -82,16 +83,16 @@ class ReorgKernel : public framework::OpKernel { auto *x_data = x->data(); auto *out_data = out->data(); - paddle::operators::reorg_cpu reorg(x_data, W, H, C, B, stride, 1, - out_data); - for_range(reorg); + paddle::operators::space_to_depth_compute computer(x_data, W, H, C, B, + stride, 1, out_data); + for_range(computer); out->Resize(out_dims); } }; template -class ReorgGradKernel : public framework::OpKernel { +class SpaceToDepthGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *d_out = @@ -114,9 +115,9 @@ class ReorgGradKernel : public framework::OpKernel { auto *dx_data = d_x->data(); auto *dout_data = d_out->data(); - paddle::operators::reorg_cpu reorg(dout_data, W, H, C, B, stride, 0, - dx_data); - for_range(reorg); + paddle::operators::space_to_depth_compute computer(dout_data, W, H, C, B, + stride, 0, dx_data); + for_range(computer); d_x->Resize(in_dims); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e7f343508a..6688c0e99f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -154,7 +154,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', - 'reorg', + 'space_to_depth', 'affine_channel', ] @@ -7456,25 +7456,26 @@ def maxout(x, groups, name=None): return out -def reorg(x, stride, name=None): +def space_to_depth(x, stride, name=None): """ - Gives a stride to reorg the input tensor - - Here are some example: - - input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2 - - reorg will do some math work to reorder the elements of input according to stride to construt - put with shape [batch, channel * stride * stride, height/stride, width/stride] - - reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape + Gives a stride to space_to_depth the input LoDtensor + + Rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the + input LoDtensor where values from the height and width dimensions are moved to the channel dimension. + The attr stride indicates the input block size. + + space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according + to stride to construct output with shape [batch, channel * stride * stride, height/stride, width/stride]: + + space_to_depth is used to This operation is useful for resizing the activations between convolutions + (but keeping all data) Args: - x(variable): The input tensor. - stride(variable): The stride to reorg + x(variable): The input LoDtensor. + stride(variable): The stride to space_to_depth Returns: - Variable: The output tensor. + Variable: The output LoDtensor. Raises: TypeError: stride type must be a long. @@ -7484,11 +7485,11 @@ def reorg(x, stride, name=None): data = fluid.layers.data( name='data', shape=[1, 4, 2, 2], dtype='float32') - reorged = fluid.layers.reorged( + space_to_depthed = fluid.layers.space_to_depth( x=data, stride=2) """ - helper = LayerHelper("reorg", **locals()) + helper = LayerHelper("space_to_depth", **locals()) if not (isinstance(stride, int)): raise ValueError("stride must be a python Int") @@ -7501,7 +7502,7 @@ def reorg(x, stride, name=None): name=name, dtype=x.dtype, persistable=False) helper.append_op( - type="reorg", + type="space_to_depth", inputs={"X": x}, attrs={"stride": stride}, outputs={"Out": out}) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 92c60da715..9dd733a54d 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -248,7 +248,7 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.softmax(hid)) print(str(program)) - def test_reorg(self): + def test_space_to_depth(self): program = Program() with program_guard(program): data = layers.data( @@ -256,7 +256,7 @@ class TestBook(unittest.TestCase): shape=[32, 9, 6, 6], append_batch_size=False, dtype='float32') - self.assertIsNotNone(layers.reorg(data, 3)) + self.assertIsNotNone(layers.space_to_depth(data, 3)) print(str(program)) def test_sequence_unsqueeze(self): diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py similarity index 67% rename from python/paddle/fluid/tests/unittests/test_reorg_op.py rename to python/paddle/fluid/tests/unittests/test_space_to_depth_op.py index a3afabe7af..36c8cd1119 100644 --- a/python/paddle/fluid/tests/unittests/test_reorg_op.py +++ b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py @@ -19,7 +19,7 @@ import paddle.fluid as fluid from op_test import OpTest -class TestReorgOp(OpTest): +class TestSpaceToDepthOp(OpTest): @staticmethod def helper(in_, width, height, channel, batch, stride, forward, out_): channel_out = channel // (stride * stride) @@ -43,7 +43,7 @@ class TestReorgOp(OpTest): def setUp(self): self.init_data() - self.op_type = "reorg" + self.op_type = "space_to_depth" self.inputs = {"X": self.x} self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], self.x.shape[1], self.x.shape[0], self.stride, self.forward, @@ -75,7 +75,35 @@ class TestReorgOp(OpTest): self.check_grad_with_place(place, ['X'], 'Out') -class TestReorgOp2(TestReorgOp): +class TestSpaceToDepthOpBasic(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 8, 6, 6) + self.infered_shape = (32, 32, 3, 3) + self.one_d_len = 32 * 32 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 8, 6, 6) + self.infered_shape = (32, 32, 3, 3) + self.one_d_len = 32 * 32 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float64') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float64') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp): def init_data(self): self.ori_shape = (32, 9, 6, 6) self.infered_shape = (32, 81, 2, 2) @@ -89,5 +117,19 @@ class TestReorgOp2(TestReorgOp): self.forward = 1 +class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 9, 9, 6) + self.infered_shape = (32, 81, 3, 2) + self.one_d_len = 32 * 81 * 3 * 2 + + self.stride = 3 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + if __name__ == '__main__': unittest.main() From 9c010146c33e36103735c93b0cc21b3968447f2c Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 25 Oct 2018 10:45:53 +0000 Subject: [PATCH 10/14] test=develop --- python/paddle/fluid/layers/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 6688c0e99f..d3b5f13b57 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7472,7 +7472,7 @@ def space_to_depth(x, stride, name=None): Args: x(variable): The input LoDtensor. - stride(variable): The stride to space_to_depth + stride(variable): The stride to select the element on each feature map Returns: Variable: The output LoDtensor. From 7bcba47e41f67036e76b52b7042aacf4c4b2eca6 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 25 Oct 2018 11:02:25 +0000 Subject: [PATCH 11/14] test=develop --- paddle/fluid/operators/space_to_depth_op.cc | 2 +- paddle/fluid/operators/space_to_depth_op.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index a9a266a3f7..1cc169bf10 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/paddle/fluid/operators/space_to_depth_op.h b/paddle/fluid/operators/space_to_depth_op.h index a236c1d5b7..4fc24138e6 100644 --- a/paddle/fluid/operators/space_to_depth_op.h +++ b/paddle/fluid/operators/space_to_depth_op.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 9a74c4489f350ad76e737e09ea177cca1cd9411e Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 29 Oct 2018 05:26:40 +0000 Subject: [PATCH 12/14] test=develop --- paddle/fluid/operators/space_to_depth_op.cc | 34 +++++++++---------- paddle/fluid/operators/space_to_depth_op.h | 26 +++++++------- python/paddle/fluid/layers/nn.py | 22 ++++++------ .../tests/unittests/test_space_to_depth_op.py | 28 +++++++-------- 4 files changed, 55 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/operators/space_to_depth_op.cc b/paddle/fluid/operators/space_to_depth_op.cc index 1cc169bf10..f109dd685c 100644 --- a/paddle/fluid/operators/space_to_depth_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -31,31 +31,31 @@ class SpaceToDepthOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); - auto stride = ctx->Attrs().Get("stride"); + auto blocksize = ctx->Attrs().Get("blocksize"); - PADDLE_ENFORCE_GT(stride, 1, "The stride should be Greater than 1"); + PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1"); PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); - PADDLE_ENFORCE_EQ(x_dims[1] % (stride * stride), 0, + PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0, "input channel should be divisible of the square of " - "SpaceToDepthOp stride"); - PADDLE_ENFORCE_EQ(x_dims[2] % (stride), 0, + "SpaceToDepthOp blocksize"); + PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0, "input Height should be divisible of the square of " - "SpaceToDepthOp stride"); - PADDLE_ENFORCE_EQ(x_dims[3] % (stride), 0, + "SpaceToDepthOp blocksize"); + PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0, "input Width should be divisible of the square of " - "SpaceToDepthOp stride"); + "SpaceToDepthOp blocksize"); VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims - << "Attribute stride" << stride << std::endl; + << "Attribute blocksize" << blocksize << std::endl; std::vector output_shape(4, 0); // [B,C,H,W] output_shape[0] = x_dims[0]; - output_shape[1] = x_dims[1] * stride * stride; - output_shape[2] = x_dims[2] / stride; - output_shape[3] = x_dims[3] / stride; + output_shape[1] = x_dims[1] * blocksize * blocksize; + output_shape[2] = x_dims[2] / blocksize; + output_shape[3] = x_dims[3] / blocksize; auto out_dims = framework::make_ddim(output_shape); @@ -80,20 +80,20 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " "SpaceToDepthOp operator."); AddAttr( - "stride", - "(int64_t, default 2) stride used to do change Space To Depth.") + "blocksize", + "(int64_t, default 2) blocksize used to do change Space To Depth.") .SetDefault(2) .GreaterThan(1); AddComment(R"DOC( reorg operator used in Yolo v2. - The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, + The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize, - Reshape Input(X) into the shape according to Attr(stride). The + Reshape Input(X) into the shape according to Attr(blocksize). The data in Input(X) are unchanged. Examples: - 1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) + 1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X) into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged. )DOC"); diff --git a/paddle/fluid/operators/space_to_depth_op.h b/paddle/fluid/operators/space_to_depth_op.h index 4fc24138e6..a71662b481 100644 --- a/paddle/fluid/operators/space_to_depth_op.h +++ b/paddle/fluid/operators/space_to_depth_op.h @@ -25,19 +25,19 @@ template class space_to_depth_compute { public: HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c, - int64_t batch, int64_t stride, + int64_t batch, int64_t blocksize, int64_t forward, T *out) : x_(x), w_(w), h_(h), c_(c), batch_(batch), - stride_(stride), + blocksize_(blocksize), forward_(forward), out_(out) {} HOSTDEVICE void operator()(int64_t in_index) { - int64_t out_c = c_ / (stride_ * stride_); + int64_t out_c = c_ / (blocksize_ * blocksize_); // calculate each dim position with index of tensor int64_t b = in_index / (c_ * h_ * w_); int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_); @@ -46,10 +46,10 @@ class space_to_depth_compute { int64_t c2 = k % out_c; int64_t offset = k / out_c; - int64_t w2 = i * stride_ + offset % stride_; - int64_t h2 = j * stride_ + offset / stride_; + int64_t w2 = i * blocksize_ + offset % blocksize_; + int64_t h2 = j * blocksize_ + offset / blocksize_; int64_t out_index = - w2 + w_ * stride_ * (h2 + h_ * stride_ * (c2 + out_c * b)); + w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b)); if (forward_) out_[out_index] = x_[in_index]; else @@ -58,7 +58,7 @@ class space_to_depth_compute { private: const T *x_; - int64_t w_, h_, c_, batch_, stride_, forward_; + int64_t w_, h_, c_, batch_, blocksize_, forward_; T *out_; }; @@ -68,7 +68,7 @@ class SpaceToDepthKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &context) const override { auto *out = context.Output("Out"); auto *x = context.Input("X"); - auto stride = context.Attr("stride"); + auto blocksize = context.Attr("blocksize"); auto in_dims = x->dims(); out->mutable_data(context.GetPlace(), x->type()); @@ -83,8 +83,8 @@ class SpaceToDepthKernel : public framework::OpKernel { auto *x_data = x->data(); auto *out_data = out->data(); - paddle::operators::space_to_depth_compute computer(x_data, W, H, C, B, - stride, 1, out_data); + paddle::operators::space_to_depth_compute computer( + x_data, W, H, C, B, blocksize, 1, out_data); for_range(computer); out->Resize(out_dims); @@ -99,7 +99,7 @@ class SpaceToDepthGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); auto *d_x = context.Output(framework::GradVarName("X")); - auto stride = context.Attr("stride"); + auto blocksize = context.Attr("blocksize"); auto in_dims = d_x->dims(); d_x->mutable_data(context.GetPlace(), d_out->type()); @@ -115,8 +115,8 @@ class SpaceToDepthGradKernel : public framework::OpKernel { auto *dx_data = d_x->data(); auto *dout_data = d_out->data(); - paddle::operators::space_to_depth_compute computer(dout_data, W, H, C, B, - stride, 0, dx_data); + paddle::operators::space_to_depth_compute computer( + dout_data, W, H, C, B, blocksize, 0, dx_data); for_range(computer); d_x->Resize(in_dims); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c762633c60..5659eafd04 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7485,29 +7485,29 @@ def maxout(x, groups, name=None): return out -def space_to_depth(x, stride, name=None): +def space_to_depth(x, blocksize, name=None): """ - Gives a stride to space_to_depth the input LoDtensor + Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width] - Rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the + This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the input LoDtensor where values from the height and width dimensions are moved to the channel dimension. - The attr stride indicates the input block size. + The attr blocksize indicates the input block size. space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according - to stride to construct output with shape [batch, channel * stride * stride, height/stride, width/stride]: + to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]: space_to_depth is used to This operation is useful for resizing the activations between convolutions (but keeping all data) Args: x(variable): The input LoDtensor. - stride(variable): The stride to select the element on each feature map + blocksize(variable): The blocksize to select the element on each feature map Returns: Variable: The output LoDtensor. Raises: - TypeError: stride type must be a long. + TypeError: blocksize type must be a long. Examples: .. code-block:: python @@ -7515,13 +7515,13 @@ def space_to_depth(x, stride, name=None): data = fluid.layers.data( name='data', shape=[1, 4, 2, 2], dtype='float32') space_to_depthed = fluid.layers.space_to_depth( - x=data, stride=2) + x=data, blocksize=2) """ helper = LayerHelper("space_to_depth", **locals()) - if not (isinstance(stride, int)): - raise ValueError("stride must be a python Int") + if not (isinstance(blocksize, int)): + raise ValueError("blocksize must be a python Int") if name is None: out = helper.create_variable_for_type_inference( @@ -7533,7 +7533,7 @@ def space_to_depth(x, stride, name=None): helper.append_op( type="space_to_depth", inputs={"X": x}, - attrs={"stride": stride}, + attrs={"blocksize": blocksize}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py index 36c8cd1119..5fdad44f12 100644 --- a/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py +++ b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py @@ -21,8 +21,8 @@ from op_test import OpTest class TestSpaceToDepthOp(OpTest): @staticmethod - def helper(in_, width, height, channel, batch, stride, forward, out_): - channel_out = channel // (stride * stride) + def helper(in_, width, height, channel, batch, blocksize, forward, out_): + channel_out = channel // (blocksize * blocksize) for b in range(batch): for k in range(channel): for j in range(height): @@ -30,10 +30,10 @@ class TestSpaceToDepthOp(OpTest): in_index = i + width * (j + height * (k + channel * b)) channel2 = k % channel_out offset = k // channel_out - width2 = i * stride + offset % stride - height2 = j * stride + offset // stride - out_index = width2 + width * stride * ( - height2 + height * stride * + width2 = i * blocksize + offset % blocksize + height2 = j * blocksize + offset // blocksize + out_index = width2 + width * blocksize * ( + height2 + height * blocksize * (channel2 + channel_out * b)) if forward: out_[out_index] = in_[in_index] @@ -46,10 +46,10 @@ class TestSpaceToDepthOp(OpTest): self.op_type = "space_to_depth" self.inputs = {"X": self.x} self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], - self.x.shape[1], self.x.shape[0], self.stride, self.forward, - self.out_1d) + self.x.shape[1], self.x.shape[0], self.blocksize, + self.forward, self.out_1d) self.out = np.reshape(self.out_1d, self.infered_shape) - self.attrs = {"stride": self.stride} + self.attrs = {"blocksize": self.blocksize} self.outputs = {"Out": self.out} def init_data(self): @@ -57,7 +57,7 @@ class TestSpaceToDepthOp(OpTest): self.infered_shape = (32, 48, 3, 3) self.one_d_len = 32 * 48 * 3 * 3 - self.stride = 2 + self.blocksize = 2 self.x = np.random.random(self.ori_shape).astype('float32') self.x_1d = np.reshape(self.x, self.one_d_len) self.out = np.zeros(self.infered_shape).astype('float32') @@ -81,7 +81,7 @@ class TestSpaceToDepthOpBasic(TestSpaceToDepthOp): self.infered_shape = (32, 32, 3, 3) self.one_d_len = 32 * 32 * 3 * 3 - self.stride = 2 + self.blocksize = 2 self.x = np.random.random(self.ori_shape).astype('float32') self.x_1d = np.reshape(self.x, self.one_d_len) self.out = np.zeros(self.infered_shape).astype('float32') @@ -95,7 +95,7 @@ class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp): self.infered_shape = (32, 32, 3, 3) self.one_d_len = 32 * 32 * 3 * 3 - self.stride = 2 + self.blocksize = 2 self.x = np.random.random(self.ori_shape).astype('float64') self.x_1d = np.reshape(self.x, self.one_d_len) self.out = np.zeros(self.infered_shape).astype('float64') @@ -109,7 +109,7 @@ class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp): self.infered_shape = (32, 81, 2, 2) self.one_d_len = 32 * 81 * 2 * 2 - self.stride = 3 + self.blocksize = 3 self.x = np.random.random(self.ori_shape).astype('float32') self.x_1d = np.reshape(self.x, self.one_d_len) self.out = np.zeros(self.infered_shape).astype('float32') @@ -123,7 +123,7 @@ class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp): self.infered_shape = (32, 81, 3, 2) self.one_d_len = 32 * 81 * 3 * 2 - self.stride = 3 + self.blocksize = 3 self.x = np.random.random(self.ori_shape).astype('float32') self.x_1d = np.reshape(self.x, self.one_d_len) self.out = np.zeros(self.infered_shape).astype('float32') From 0e3038680b607ce441d285c4fd3a4e4cb75cad16 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Mon, 29 Oct 2018 06:35:30 +0000 Subject: [PATCH 13/14] test=develop --- paddle/fluid/API.spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d317117bcf..1f7e17d327 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -174,7 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) From 45565784bff06ced07829071a3be30dce5871c64 Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 1 Nov 2018 08:45:53 +0000 Subject: [PATCH 14/14] test=develop --- python/paddle/fluid/layers/nn.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 90af75a24f..69f0f8dc89 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7506,9 +7506,16 @@ def space_to_depth(x, blocksize, name=None): space_to_depth is used to This operation is useful for resizing the activations between convolutions (but keeping all data) + - Non-overlapping blocks of size block_size x block size are rearranged into depth at each location. + - The depth of the output tensor is block_size * block_size * input channel + - The Y, X coordinates within each block of the input become the high order component of the output channel index + - channel should be divisible by square of blocksize + - height, width should be divsible by blocksize + + Args: x(variable): The input LoDtensor. - blocksize(variable): The blocksize to select the element on each feature map + blocksize(variable): The blocksize to select the element on each feature map should be > 2 Returns: Variable: The output LoDtensor.