From 70729ad6416eecb8cb7f4e1d648f83e92bb73bdf Mon Sep 17 00:00:00 2001 From: chenweihang Date: Fri, 29 Jun 2018 13:13:05 +0000 Subject: [PATCH 1/8] Add Unsqueeze Operator Framework, not finshed --- paddle/fluid/operators/unsqueeze_op.cc | 148 ++++++++++++++++++ paddle/fluid/operators/unsqueeze_op.cu | 30 ++++ paddle/fluid/operators/unsqueeze_op.h | 72 +++++++++ .../tests/unittests/test_unsqueeze_op.py | 98 ++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 paddle/fluid/operators/unsqueeze_op.cc create mode 100644 paddle/fluid/operators/unsqueeze_op.cu create mode 100644 paddle/fluid/operators/unsqueeze_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_unsqueeze_op.py diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc new file mode 100644 index 0000000000..8d2a186685 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unsqueeze_op.h" +#include +#include + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class UnsqueezeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnsqueezeOp should not be null."); + + const auto& x_dims = ctx->GetInputDim("X"); + const auto& axes = ctx->Attrs().Get>("axes"); + // Check output tensor dims (<9). + PADDLE_ENFORCE_LE(x_dims.size() + axes.size(), 9, + "Invalid dimnesions, dynamic dimensions must have " + "between [1, 9] dimensions."); + // Check the range of unsqueeze aixs. + for (int a : axes) { + PADDLE_ENFORCE_LT(a, static_cast(x_dims.size() + axes.size()), + "The axis must be less than output tensor's rank."); + } + + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + } + + static framework::DDim GetOutputShape(const std::vector unsqueeze_dims, + const framework::DDim& in_dims) { + int out_dims_size = in_dims.size() + unsqueeze_dims.size(); + bool should_unsqueeze[9] = {false}; + + // Determines the dimensions should be unsqueezed in output tensor after. + for (unsigned int idx = 0; idx < unsqueeze_dims.size(); ++idx) { + int current = unsqueeze_dims[idx] < 0 + ? unsqueeze_dims[idx] + out_dims_size + : unsqueeze_dims[idx]; + // Check current index. + PADDLE_ENFORCE_GE(current, 0, + "Invaild axis, negative axis is out of range."); + should_unsqueeze[idx] = true; + } + + // Make output dimensions + std::vector output_shape(out_dims_size, 0); + for (int in_idx = 0, out_idx = 0; out_idx < out_dims_size; ++out_idx) { + if (!should_unsqueeze[out_idx]) { + output_shape[out_idx] = in_dims[in_idx++]; + } else { + output_shape[out_idx] = 1; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); + AddAttr>("axes", + "(std::vector). List of positive integers," + " indicate the dimensions to be inserted"); + AddAttr( + "inplace", + "(default: false) Unsqueeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. + + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + )DOC"); + } +}; + +class UnsqueezeGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Output(Out@GRAD) of UnsqueezeGradOp should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); +REGISTER_OP_CPU_KERNEL( + unsqueeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CPU_KERNEL( + unsqueeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.cu b/paddle/fluid/operators/unsqueeze_op.cu new file mode 100644 index 0000000000..891f6cc548 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/fluid/operators/unsqueeze_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + squeeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CUDA_KERNEL( + squeeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h new file mode 100644 index 0000000000..aa45fb3113 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class UnsqueezeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + framework::DDim out_dims = out->dims(); + + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } + } +}; + +template +class UnsqueezeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py new file mode 100644 index 0000000000..273a2c075f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -0,0 +1,98 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np + +from op_test import OpTest + + +# Correct: General. +class TestSqueezeOp1(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, 2) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: There is mins axis. +class TestSqueezeOp2(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, -2) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inplace. +class TestUnsqueezeOpInplace1(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, 2) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inplace": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Inplace. There is mins axis. +class TestUnsqueezeOpInplace2(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, -2) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": True} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +if __name__ == "__main__": + unittest.main() From e402496238cd1f2132f7a2e4f354404acdf6dcbb Mon Sep 17 00:00:00 2001 From: chenweihang Date: Mon, 2 Jul 2018 03:23:27 +0000 Subject: [PATCH 2/8] complete unsqueeze op and related unittest. --- paddle/fluid/operators/unsqueeze_op.cc | 113 ++++++++++++------ paddle/fluid/operators/unsqueeze_op.cu | 4 +- .../tests/unittests/test_unsqueeze_op.py | 98 +++++++++++++-- 3 files changed, 167 insertions(+), 48 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 8d2a186685..373dac8bab 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -32,42 +32,85 @@ class UnsqueezeOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UnsqueezeOp should not be null."); - const auto& x_dims = ctx->GetInputDim("X"); const auto& axes = ctx->Attrs().Get>("axes"); - // Check output tensor dims (<9). - PADDLE_ENFORCE_LE(x_dims.size() + axes.size(), 9, - "Invalid dimnesions, dynamic dimensions must have " - "between [1, 9] dimensions."); - // Check the range of unsqueeze aixs. - for (int a : axes) { - PADDLE_ENFORCE_LT(a, static_cast(x_dims.size() + axes.size()), - "The axis must be less than output tensor's rank."); + PADDLE_ENFORCE(!axes.empty(), + "The unsqueeze axes information must be set by Attr(axes)."); + + const auto& x_dims = ctx->GetInputDim("X"); + // Validity Check: input tensor dims (<6). + PADDLE_ENFORCE(x_dims.size() < 6, + "Invalid dimensions, dynamic dimensions should within " + "[0, 5] dimensions (Eigen limit)."); + // Validity Check: the range of unsqueeze aixs. + // TODO(chenweihang): Don't consider negative axis?. + for (unsigned int idx = 0; idx < axes.size(); ++idx) { + PADDLE_ENFORCE(axes[idx] < 6, + "Invalid dimensions, input axis should within " + "[0, 5] dimensions (Eigen limit)."); } auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); } - static framework::DDim GetOutputShape(const std::vector unsqueeze_dims, + static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim& in_dims) { - int out_dims_size = in_dims.size() + unsqueeze_dims.size(); - bool should_unsqueeze[9] = {false}; - - // Determines the dimensions should be unsqueezed in output tensor after. - for (unsigned int idx = 0; idx < unsqueeze_dims.size(); ++idx) { - int current = unsqueeze_dims[idx] < 0 - ? unsqueeze_dims[idx] + out_dims_size - : unsqueeze_dims[idx]; - // Check current index. - PADDLE_ENFORCE_GE(current, 0, - "Invaild axis, negative axis is out of range."); - should_unsqueeze[idx] = true; + /* + * STL version + * Test Error! don't know why?. + std::vector output_shape; + + // Contruct base output shape + for(int idx = 0; idx < in_dims.size(); ++idx) { + output_shape.emplace_back(in_dims[idx]); + } + // Validity Check: output dimensions limit. + PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6, + "The Attr(axes) size is too large. The output shape should " + "be less than 6 (Eigne limit)."); + // Insert the unsqueeze axis in turn. + auto it = output_shape.begin(); + for (int axis : unsqz_dims) { + int cur = axis < 0 ? (axis + output_shape.size() + 1) + : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE(cur >= 0 && cur <= static_cast(output_shape.size()), + "The unsqueeze dims must be within range of current + rank."); + output_shape.emplace(it + axis, 1); + } + */ + + unsigned int unsqz_mask = 0; + unsigned int front = 0, back = 0; + int output_dims_size = in_dims.size(); + + // Simulate insert by bit calc. + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + output_dims_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE( + cur >= 0 && cur <= output_dims_size, + "The unsqueeze dims must be within range of current rank."); + // Save the front part. + front = unsqz_mask & ((1 << axis) - 1); + // Move the back part. + back = unsqz_mask & ~((1 << axis) - 1); + back <<= 1; + // Merge two part. + back |= (1 << axis); + unsqz_mask = front | back; + // Add the output size. + output_dims_size++; + // Validity Check: rank range. + PADDLE_ENFORCE(output_dims_size < 6, + "The output tensor's rank should be less than 6."); } - // Make output dimensions - std::vector output_shape(out_dims_size, 0); - for (int in_idx = 0, out_idx = 0; out_idx < out_dims_size; ++out_idx) { - if (!should_unsqueeze[out_idx]) { + // Make output shape + std::vector output_shape(output_dims_size, 0); + for (int in_idx = 0, out_idx = 0; out_idx < output_dims_size; ++out_idx) { + if ((unsqz_mask & (1 << out_idx)) == 0) { output_shape[out_idx] = in_dims[in_idx++]; } else { output_shape[out_idx] = 1; @@ -94,15 +137,15 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "tensor is created, and its data are copied from Input(x).") .SetDefault(false); AddComment(R"DOC( - Unsqueeze Operator. - - Insert single-dimensional entries to the shape of a tensor. - Takes one required argument axes, a list of dimensions that will be inserted. - Dimension indices in axes are as seen in the output tensor. - - For example: - Given a tensor such that tensor with shape [3, 4, 5], - then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. + + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] )DOC"); } }; diff --git a/paddle/fluid/operators/unsqueeze_op.cu b/paddle/fluid/operators/unsqueeze_op.cu index 891f6cc548..4d111190cd 100644 --- a/paddle/fluid/operators/unsqueeze_op.cu +++ b/paddle/fluid/operators/unsqueeze_op.cu @@ -18,12 +18,12 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - squeeze, ops::UnsqueezeKernel, + unsqueeze, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel, ops::UnsqueezeKernel); REGISTER_OP_CUDA_KERNEL( - squeeze_grad, + unsqueeze_grad, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, ops::UnsqueezeGradKernel, diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 273a2c075f..af273ca5a1 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ from op_test import OpTest # Correct: General. -class TestSqueezeOp1(OpTest): +class TestUnsqueezeOp(OpTest): def setUp(self): ori_shape = (3, 5) axes = (0, 2) @@ -38,7 +38,7 @@ class TestSqueezeOp1(OpTest): # Correct: There is mins axis. -class TestSqueezeOp2(OpTest): +class TestUnsqueezeOp2(OpTest): def setUp(self): ori_shape = (3, 5) axes = (0, -2) @@ -56,6 +56,82 @@ class TestSqueezeOp2(OpTest): self.check_grad(["X"], "Out") +# Correct: There is duplicated axis. +class TestUnsqueezeOp3(OpTest): + def setUp(self): + ori_shape = (3, 2, 5) + axes = (0, 3, 3) + new_shape = (1, 3, 2, 1, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Error: Output dimension is error. +class TestUnsqueezeOp4(OpTest): + def setUp(self): + ori_shape = (3, 2, 5) + axes = (0, 3) + new_shape = (1, 3, 2, 2, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Error: Input axes is invalid case 1. +class TestUnsqueezeOp5(OpTest): + def setUp(self): + ori_shape = (3, 2, 5) + axes = (0, 5) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Error: Input axes is invalid case 2. +class TestUnsqueezeOp5(OpTest): + def setUp(self): + ori_shape = (3, 2, 5) + axes = (0, 2, 10) + new_shape = (1, 3, 1, 5) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inpalce": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + # Correct: Inplace. class TestUnsqueezeOpInplace1(OpTest): def setUp(self): @@ -75,23 +151,23 @@ class TestUnsqueezeOpInplace1(OpTest): self.check_grad(["X"], "Out") -# Correct: Inplace. There is mins axis. +# Correct: Inplace. There is duplicated axis. class TestUnsqueezeOpInplace2(OpTest): def setUp(self): - ori_shape = (3, 5) - axes = (0, -2) - new_shape = (1, 3, 1, 5) + ori_shape = (3, 2, 5) + axes = (0, 3, 3) + new_shape = (1, 3, 2, 1, 1, 5) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} self.attrs = {"axes": axes, "inpalce": True} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") if __name__ == "__main__": From ca1577939444cf702d4b131ac0afa8bfbad0211d Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 3 Jul 2018 03:58:48 +0000 Subject: [PATCH 3/8] rewrite, use reshape op in unsqueeze op, test passed --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/unsqueeze_op.cc | 146 +++++++-------- paddle/fluid/operators/unsqueeze_op.cu | 30 ---- paddle/fluid/operators/unsqueeze_op.h | 72 -------- .../tests/unittests/test_unsqueeze_op.py | 167 ++++++++++++------ 5 files changed, 185 insertions(+), 231 deletions(-) delete mode 100644 paddle/fluid/operators/unsqueeze_op.cu delete mode 100644 paddle/fluid/operators/unsqueeze_op.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ab1d214333..50f5f34021 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(unsqueeze_op DEPS reshape_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 373dac8bab..c503988676 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -12,41 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/unsqueeze_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { -using framework::OpKernelType; -using framework::Tensor; - -class UnsqueezeOp : public framework::OperatorWithKernel { +class UnsqueezeOpInferShape : public framework::InferShapeBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { + void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnsqueezeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UnsqueezeOp should not be null."); - const auto& axes = ctx->Attrs().Get>("axes"); + const auto &axes = ctx->Attrs().Get>("axes"); PADDLE_ENFORCE(!axes.empty(), "The unsqueeze axes information must be set by Attr(axes)."); - const auto& x_dims = ctx->GetInputDim("X"); + const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE(x_dims.size() < 6, + PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, "Invalid dimensions, dynamic dimensions should within " - "[0, 5] dimensions (Eigen limit)."); + "[1, 6] dimensions (Eigen limit)."); // Validity Check: the range of unsqueeze aixs. - // TODO(chenweihang): Don't consider negative axis?. - for (unsigned int idx = 0; idx < axes.size(); ++idx) { - PADDLE_ENFORCE(axes[idx] < 6, + for (int axis : axes) { + PADDLE_ENFORCE(axis < 6, "Invalid dimensions, input axis should within " - "[0, 5] dimensions (Eigen limit)."); + "[1, 6] dimensions (Eigen limit)."); } auto out_dims = GetOutputShape(axes, x_dims); @@ -54,33 +48,7 @@ class UnsqueezeOp : public framework::OperatorWithKernel { } static framework::DDim GetOutputShape(const std::vector unsqz_dims, - const framework::DDim& in_dims) { - /* - * STL version - * Test Error! don't know why?. - std::vector output_shape; - - // Contruct base output shape - for(int idx = 0; idx < in_dims.size(); ++idx) { - output_shape.emplace_back(in_dims[idx]); - } - // Validity Check: output dimensions limit. - PADDLE_ENFORCE(unsqz_dims.size() + output_shape.size() < 6, - "The Attr(axes) size is too large. The output shape should " - "be less than 6 (Eigne limit)."); - // Insert the unsqueeze axis in turn. - auto it = output_shape.begin(); - for (int axis : unsqz_dims) { - int cur = axis < 0 ? (axis + output_shape.size() + 1) - : axis; - // Vaildity Check: the axis bound - PADDLE_ENFORCE(cur >= 0 && cur <= static_cast(output_shape.size()), - "The unsqueeze dims must be within range of current - rank."); - output_shape.emplace(it + axis, 1); - } - */ - + const framework::DDim &in_dims) { unsigned int unsqz_mask = 0; unsigned int front = 0, back = 0; int output_dims_size = in_dims.size(); @@ -93,17 +61,17 @@ class UnsqueezeOp : public framework::OperatorWithKernel { cur >= 0 && cur <= output_dims_size, "The unsqueeze dims must be within range of current rank."); // Save the front part. - front = unsqz_mask & ((1 << axis) - 1); + front = unsqz_mask & ((1 << cur) - 1); // Move the back part. - back = unsqz_mask & ~((1 << axis) - 1); + back = unsqz_mask & ~((1 << cur) - 1); back <<= 1; // Merge two part. - back |= (1 << axis); + back |= (1 << cur); unsqz_mask = front | back; // Add the output size. output_dims_size++; // Validity Check: rank range. - PADDLE_ENFORCE(output_dims_size < 6, + PADDLE_ENFORCE(output_dims_size <= 6, "The output tensor's rank should be less than 6."); } @@ -121,6 +89,31 @@ class UnsqueezeOp : public framework::OperatorWithKernel { } }; +class UnsqueezeOp : public framework::OperatorBase { + public: + UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape op. + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -150,42 +143,49 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class UnsqueezeGradOp : public framework::OperatorWithKernel { +class UnsqueezeGradInferShape : public framework::InferShapeBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of UnsqueezeGradOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Output(Out@GRAD) of UnsqueezeGradOp should not be null."); + void operator()(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", framework::GradVarName("X")); } +}; - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); +class UnsqueezeGradOp : public framework::OperatorBase { + public: + UnsqueezeGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); } }; } // namespace operators } // namespace paddle +// Tell linker to use reshape op. +USE_OP(reshape); + namespace ops = paddle::operators; REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + ops::UnsqueezeOpInferShape, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); -REGISTER_OP_CPU_KERNEL( - unsqueeze, ops::UnsqueezeKernel, - ops::UnsqueezeKernel, - ops::UnsqueezeKernel, - ops::UnsqueezeKernel); -REGISTER_OP_CPU_KERNEL( - unsqueeze_grad, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp, + ops::UnsqueezeGradInferShape); diff --git a/paddle/fluid/operators/unsqueeze_op.cu b/paddle/fluid/operators/unsqueeze_op.cu deleted file mode 100644 index 4d111190cd..0000000000 --- a/paddle/fluid/operators/unsqueeze_op.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#define EIGEN_USE_GPU - -#include "paddle/fluid/operators/unsqueeze_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - unsqueeze, ops::UnsqueezeKernel, - ops::UnsqueezeKernel, - ops::UnsqueezeKernel, - ops::UnsqueezeKernel); -REGISTER_OP_CUDA_KERNEL( - unsqueeze_grad, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel, - ops::UnsqueezeGradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h deleted file mode 100644 index aa45fb3113..0000000000 --- a/paddle/fluid/operators/unsqueeze_op.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class UnsqueezeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *out = ctx.Output("Out"); - auto *in = ctx.Input("X"); - - framework::DDim out_dims = out->dims(); - - bool inplace = ctx.Attr("inplace"); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(ctx.GetPlace()); - framework::TensorCopySync(*in, ctx.GetPlace(), out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*in); - out->Resize(out_dims); - } - } -}; - -template -class UnsqueezeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *d_out = ctx.Input(framework::GradVarName("Out")); - auto *d_x = ctx.Output(framework::GradVarName("X")); - - d_x->mutable_data(ctx.GetPlace()); - bool inplace = ctx.Attr("inplace"); - - auto in_dims = d_x->dims(); - if (!inplace) { - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - ctx.device_context().Wait(); - d_x->Resize(in_dims); - } else { - d_x->ShareDataWith(*d_out); - d_x->Resize(in_dims); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index af273ca5a1..eff90f4618 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -27,7 +27,7 @@ class TestUnsqueezeOp(OpTest): self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": False} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} def test_check_output(self): @@ -37,23 +37,42 @@ class TestUnsqueezeOp(OpTest): self.check_grad(["X"], "Out") -# Correct: There is mins axis. +# Correct: Single input index. +class TestUnsqueezeOp1(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (-1, ) + new_shape = (3, 5, 1) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inplace": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +# Correct: Mixed input axis. class TestUnsqueezeOp2(OpTest): def setUp(self): ori_shape = (3, 5) - axes = (0, -2) - new_shape = (1, 3, 1, 5) + axes = (0, -1) + new_shape = (1, 3, 5, 1) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": False} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") # Correct: There is duplicated axis. @@ -65,83 +84,84 @@ class TestUnsqueezeOp3(OpTest): self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": False} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") -# Error: Output dimension is error. -class TestUnsqueezeOp4(OpTest): +# Correct: Inplace. +class TestUnsqueezeOpInplace1(OpTest): def setUp(self): - ori_shape = (3, 2, 5) - axes = (0, 3) - new_shape = (1, 3, 2, 2, 5) + ori_shape = (3, 5) + axes = (0, 2) + new_shape = (1, 3, 1, 5) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": True} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") -# Error: Input axes is invalid case 1. -class TestUnsqueezeOp5(OpTest): +# Correct: Inplace. There is mins index. +class TestUnsqueezeOpInplace2(OpTest): def setUp(self): - ori_shape = (3, 2, 5) - axes = (0, 5) + ori_shape = (3, 5) + axes = (0, -2) new_shape = (1, 3, 1, 5) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": True} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") -# Error: Input axes is invalid case 2. -class TestUnsqueezeOp5(OpTest): +# Correct: Inplace. There is duplicated axis. +class TestUnsqueezeOpInplace3(OpTest): def setUp(self): ori_shape = (3, 2, 5) - axes = (0, 2, 10) - new_shape = (1, 3, 1, 5) + axes = (0, 3, 3) + new_shape = (1, 3, 2, 1, 1, 5) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": False} + self.attrs = {"axes": axes, "inplace": True} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") + def test_check_grad(self): + self.check_grad(["X"], "Out") -# Correct: Inplace. -class TestUnsqueezeOpInplace1(OpTest): +''' +# Error: Output dimension is error. +class TestUnsqueezeOp4(OpTest): def setUp(self): ori_shape = (3, 5) - axes = (0, 2) - new_shape = (1, 3, 1, 5) + axes = (0, 3) + new_shape = (1, 3, 1, 1, 5) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} + self.attrs = {"axes": axes, "inplace": False} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} def test_check_output(self): @@ -150,25 +170,60 @@ class TestUnsqueezeOpInplace1(OpTest): def test_check_grad(self): self.check_grad(["X"], "Out") - -# Correct: Inplace. There is duplicated axis. -class TestUnsqueezeOpInplace2(OpTest): +# Error: Input axis is large than output range. +class TestUnsqueezeOp5(OpTest): def setUp(self): - ori_shape = (3, 2, 5) - axes = (0, 3, 3) - new_shape = (1, 3, 2, 1, 1, 5) + ori_shape = (3, 5) + axes = (0, 4) + new_shape = (1, 3, 5, 1) self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inpalce": True} + self.attrs = {"axes": axes, "inplace": False} self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - def test_check_output(self): - self.check_output() + def test_check_output(self): + self.check_output() def test_check_grad(self): self.check_grad(["X"], "Out") +# Error: Input axes is large than Eigen limit. +class TestUnsqueezeOp6(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, 2, 10) + new_shape = (1, 3, 1, 5, 1) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inplace": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + +# Error: Input axes size is large than Eigen limit. +class TestUnsqueezeOp7(OpTest): + def setUp(self): + ori_shape = (3, 5) + axes = (0, 2, 2, 2, 2, 2) + new_shape = (1, 3, 1, 1, 5, 1) + + self.op_type = "unsqueeze" + self.inputs = {"X": np.random.random(ori_shape).astype("float32")} + self.attrs = {"axes": axes, "inplace": False} + self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") +''' if __name__ == "__main__": unittest.main() From 49b2cf5feee66010c6598f8d4fc49f1fc1f29048 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Wed, 4 Jul 2018 09:44:39 +0000 Subject: [PATCH 4/8] adjust some code based reviewer's advice --- paddle/fluid/operators/unsqueeze_op.cc | 30 ++- .../tests/unittests/test_unsqueeze_op.py | 216 ++++-------------- 2 files changed, 60 insertions(+), 186 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index c503988676..62e45468ab 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,15 +36,13 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, "Invalid dimensions, dynamic dimensions should within " "[1, 6] dimensions (Eigen limit)."); - // Validity Check: the range of unsqueeze aixs. - for (int axis : axes) { - PADDLE_ENFORCE(axis < 6, - "Invalid dimensions, input axis should within " - "[1, 6] dimensions (Eigen limit)."); - } - auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", "Out"); + } } static framework::DDim GetOutputShape(const std::vector unsqz_dims, @@ -102,6 +100,8 @@ class UnsqueezeOp : public framework::OperatorBase { auto &axes = Attr>("axes"); auto x_dims = scope.FindVar(Input("X"))->Get().dims(); auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); + // auto out_dims = + // scope.FindVar(Output("Out"))->Get().dims(); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); @@ -121,7 +121,19 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", "(std::vector). List of positive integers," - " indicate the dimensions to be inserted"); + " indicate the dimensions to be inserted") + .AddCustomChecker([](const std::vector &axes) { + // Validity Check: axes dims (<6). + PADDLE_ENFORCE(static_cast(axes.size()) < 6, + "Invalid dimensions, dynamic dimensions should within " + "[1, 6] dimensions (Eigen limit)."); + // Validity Check: the range of unsqueeze aixs. + for (int axis : axes) { + PADDLE_ENFORCE(axis < 6, + "Invalid dimensions, input axis should within " + "[1, 6] dimensions (Eigen limit)."); + } + }); AddAttr( "inplace", "(default: false) Unsqueeze the source tensor's shape without " diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index eff90f4618..62dc6fcb9e 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -21,14 +21,11 @@ from op_test import OpTest # Correct: General. class TestUnsqueezeOp(OpTest): def setUp(self): - ori_shape = (3, 5) - axes = (0, 2) - new_shape = (1, 3, 1, 5) - + self.init_test_case() self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} + self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} + self.attrs = {"axes": self.axes, "inplace": False} + self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} def test_check_output(self): self.check_output() @@ -36,194 +33,59 @@ class TestUnsqueezeOp(OpTest): def test_check_grad(self): self.check_grad(["X"], "Out") + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (1, 2) + self.new_shape = (3, 1, 1, 5) -# Correct: Single input index. -class TestUnsqueezeOp1(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (-1, ) - new_shape = (3, 5, 1) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - def test_check_grad(self): - self.check_grad(["X"], "Out") +# Correct: Single input index. +class TestUnsqueezeOp1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (-1, ) + self.new_shape = (3, 5, 1) # Correct: Mixed input axis. -class TestUnsqueezeOp2(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, -1) - new_shape = (1, 3, 5, 1) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestUnsqueezeOp2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -1) + self.new_shape = (1, 3, 5, 1) # Correct: There is duplicated axis. -class TestUnsqueezeOp3(OpTest): - def setUp(self): - ori_shape = (3, 2, 5) - axes = (0, 3, 3) - new_shape = (1, 3, 2, 1, 1, 5) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestUnsqueezeOp3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) # Correct: Inplace. -class TestUnsqueezeOpInplace1(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, 2) - new_shape = (1, 3, 1, 5) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestUnsqueezeOpInplace1(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, 2) + self.new_shape = (1, 3, 1, 5) # Correct: Inplace. There is mins index. -class TestUnsqueezeOpInplace2(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, -2) - new_shape = (1, 3, 1, 5) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") +class TestUnsqueezeOpInplace2(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 5) + self.axes = (0, -2) + self.new_shape = (1, 3, 1, 5) # Correct: Inplace. There is duplicated axis. -class TestUnsqueezeOpInplace3(OpTest): - def setUp(self): - ori_shape = (3, 2, 5) - axes = (0, 3, 3) - new_shape = (1, 3, 2, 1, 1, 5) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": True} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") - - -''' -# Error: Output dimension is error. -class TestUnsqueezeOp4(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, 3) - new_shape = (1, 3, 1, 1, 5) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") - -# Error: Input axis is large than output range. -class TestUnsqueezeOp5(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, 4) - new_shape = (1, 3, 5, 1) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} +class TestUnsqueezeOpInplace3(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (0, 3, 3) + self.new_shape = (1, 3, 2, 1, 1, 5) - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") - -# Error: Input axes is large than Eigen limit. -class TestUnsqueezeOp6(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, 2, 10) - new_shape = (1, 3, 1, 5, 1) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") - -# Error: Input axes size is large than Eigen limit. -class TestUnsqueezeOp7(OpTest): - def setUp(self): - ori_shape = (3, 5) - axes = (0, 2, 2, 2, 2, 2) - new_shape = (1, 3, 1, 1, 5, 1) - - self.op_type = "unsqueeze" - self.inputs = {"X": np.random.random(ori_shape).astype("float32")} - self.attrs = {"axes": axes, "inplace": False} - self.outputs = {"Out": self.inputs["X"].reshape(new_shape)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(["X"], "Out") -''' if __name__ == "__main__": unittest.main() From 80126a7496cc1d0c4568d7b8e5cc92c1f8bf5904 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Fri, 6 Jul 2018 06:38:24 +0000 Subject: [PATCH 5/8] small fix based reviewer's advice --- paddle/fluid/operators/unsqueeze_op.cc | 6 +++--- .../fluid/tests/unittests/test_unsqueeze_op.py | 14 +++++++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 62e45468ab..d950da6a75 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -28,9 +28,6 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { "Output(Out) of UnsqueezeOp should not be null."); const auto &axes = ctx->Attrs().Get>("axes"); - PADDLE_ENFORCE(!axes.empty(), - "The unsqueeze axes information must be set by Attr(axes)."); - const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, @@ -123,6 +120,9 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "(std::vector). List of positive integers," " indicate the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { + PADDLE_ENFORCE( + !axes.empty(), + "The unsqueeze axes information must be set by Attr(axes)."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, "Invalid dimensions, dynamic dimensions should within " diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index 62dc6fcb9e..d19d4e525a 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -24,7 +24,7 @@ class TestUnsqueezeOp(OpTest): self.init_test_case() self.op_type = "unsqueeze" self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")} - self.attrs = {"axes": self.axes, "inplace": False} + self.init_attrs() self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)} def test_check_output(self): @@ -38,6 +38,9 @@ class TestUnsqueezeOp(OpTest): self.axes = (1, 2) self.new_shape = (3, 1, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": False} + # Correct: Single input index. class TestUnsqueezeOp1(TestUnsqueezeOp): @@ -70,6 +73,9 @@ class TestUnsqueezeOpInplace1(TestUnsqueezeOp): self.axes = (0, 2) self.new_shape = (1, 3, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + # Correct: Inplace. There is mins index. class TestUnsqueezeOpInplace2(TestUnsqueezeOp): @@ -78,6 +84,9 @@ class TestUnsqueezeOpInplace2(TestUnsqueezeOp): self.axes = (0, -2) self.new_shape = (1, 3, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + # Correct: Inplace. There is duplicated axis. class TestUnsqueezeOpInplace3(TestUnsqueezeOp): @@ -86,6 +95,9 @@ class TestUnsqueezeOpInplace3(TestUnsqueezeOp): self.axes = (0, 3, 3) self.new_shape = (1, 3, 2, 1, 1, 5) + def init_attrs(self): + self.attrs = {"axes": self.axes, "inplace": True} + if __name__ == "__main__": unittest.main() From 5f89272c89befd113d1fa44e9055f47bcceb455e Mon Sep 17 00:00:00 2001 From: chenweihang Date: Mon, 9 Jul 2018 06:08:55 +0000 Subject: [PATCH 6/8] change the bit insert to array insert for understandability --- paddle/fluid/operators/unsqueeze_op.cc | 57 ++++++++----------- .../tests/unittests/test_unsqueeze_op.py | 8 +++ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index d950da6a75..960bc6f241 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -44,39 +44,37 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { - unsigned int unsqz_mask = 0; - unsigned int front = 0, back = 0; - int output_dims_size = in_dims.size(); + int output_size = in_dims.size() + unsqz_dims.size(); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE(output_size <= 6, + "The output tensor's rank should be less than 6."); - // Simulate insert by bit calc. for (int axis : unsqz_dims) { - int cur = axis < 0 ? axis + output_dims_size + 1 : axis; + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; // Vaildity Check: the axis bound PADDLE_ENFORCE( - cur >= 0 && cur <= output_dims_size, + cur >= 0 && cur <= cur_output_size, "The unsqueeze dims must be within range of current rank."); - // Save the front part. - front = unsqz_mask & ((1 << cur) - 1); - // Move the back part. - back = unsqz_mask & ~((1 << cur) - 1); - back <<= 1; - // Merge two part. - back |= (1 << cur); - unsqz_mask = front | back; + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; // Add the output size. - output_dims_size++; - // Validity Check: rank range. - PADDLE_ENFORCE(output_dims_size <= 6, - "The output tensor's rank should be less than 6."); + cur_output_size++; } // Make output shape - std::vector output_shape(output_dims_size, 0); - for (int in_idx = 0, out_idx = 0; out_idx < output_dims_size; ++out_idx) { - if ((unsqz_mask & (1 << out_idx)) == 0) { + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { output_shape[out_idx] = in_dims[in_idx++]; - } else { - output_shape[out_idx] = 1; } } @@ -86,10 +84,7 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { class UnsqueezeOp : public framework::OperatorBase { public: - UnsqueezeOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + using OperatorBase::OperatorBase; private: void RunImpl(const framework::Scope &scope, @@ -97,8 +92,6 @@ class UnsqueezeOp : public framework::OperatorBase { auto &axes = Attr>("axes"); auto x_dims = scope.FindVar(Input("X"))->Get().dims(); auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims); - // auto out_dims = - // scope.FindVar(Output("Out"))->Get().dims(); framework::AttributeMap attrs; attrs["shape"] = framework::vectorize2int(out_dims); @@ -165,11 +158,7 @@ class UnsqueezeGradInferShape : public framework::InferShapeBase { class UnsqueezeGradOp : public framework::OperatorBase { public: - UnsqueezeGradOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} + using OperatorBase::OperatorBase; private: void RunImpl(const framework::Scope &scope, diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py index d19d4e525a..7a4aa0a40b 100644 --- a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -66,6 +66,14 @@ class TestUnsqueezeOp3(TestUnsqueezeOp): self.new_shape = (1, 3, 2, 1, 1, 5) +# Correct: Reversed axes. +class TestUnsqueezeOp4(TestUnsqueezeOp): + def init_test_case(self): + self.ori_shape = (3, 2, 5) + self.axes = (3, 1, 1) + self.new_shape = (3, 1, 1, 2, 5, 1) + + # Correct: Inplace. class TestUnsqueezeOpInplace1(TestUnsqueezeOp): def init_test_case(self): From cef8dbc1f7867d013046227f8283ee249bda8a0f Mon Sep 17 00:00:00 2001 From: chenweihang Date: Tue, 10 Jul 2018 09:09:55 +0000 Subject: [PATCH 7/8] refine some messages and adjust data type --- paddle/fluid/operators/unsqueeze_op.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 5e089d77f4..da542aa852 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -30,9 +30,9 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { const auto &axes = ctx->Attrs().Get>("axes"); const auto &x_dims = ctx->GetInputDim("X"); // Validity Check: input tensor dims (<6). - PADDLE_ENFORCE(static_cast(x_dims.size()) <= 6, - "Invalid dimensions, dynamic dimensions should within " - "[1, 6] dimensions (Eigen limit)."); + PADDLE_ENFORCE(x_dims.size() <= 6, + "Invalid dimensions, the rank of Input(X) " + "should be in the range of [1, 6] (Eigen limit)"); auto out_dims = GetOutputShape(axes, x_dims); ctx->SetOutputDim("Out", out_dims); if (x_dims[0] == out_dims[0]) { @@ -44,8 +44,8 @@ class UnsqueezeOpInferShape : public framework::InferShapeBase { static framework::DDim GetOutputShape(const std::vector unsqz_dims, const framework::DDim &in_dims) { - int output_size = static_cast(in_dims.size() + unsqz_dims.size()); - int cur_output_size = static_cast(in_dims.size()); + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); std::vector output_shape(output_size, 0); // Validity Check: rank range. @@ -110,12 +110,11 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", - "(std::vector). List of positive integers," + "(std::vector). List of integers," " indicate the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { - PADDLE_ENFORCE( - !axes.empty(), - "The unsqueeze axes information must be set by Attr(axes)."); + PADDLE_ENFORCE(!axes.empty(), + "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, "Invalid dimensions, dynamic dimensions should within " From 3d159689583696757167c02815cd1859364649b2 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Wed, 11 Jul 2018 06:23:32 +0000 Subject: [PATCH 8/8] docs: fix some errors of description --- paddle/fluid/operators/unsqueeze_op.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index da542aa852..f2a15fdf57 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -111,19 +111,19 @@ class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); AddAttr>("axes", "(std::vector). List of integers," - " indicate the dimensions to be inserted") + " indicating the dimensions to be inserted") .AddCustomChecker([](const std::vector &axes) { PADDLE_ENFORCE(!axes.empty(), "Invalid axes, The unsqueeze axes is empty."); // Validity Check: axes dims (<6). PADDLE_ENFORCE(static_cast(axes.size()) < 6, - "Invalid dimensions, dynamic dimensions should within " - "[1, 6] dimensions (Eigen limit)."); + "Invalid dimensions, dynamic dimensions should be " + "within [1, 6] dimensions (Eigen limit)."); // Validity Check: the range of unsqueeze aixs. for (int axis : axes) { PADDLE_ENFORCE(axis < 6, - "Invalid dimensions, input axis should within " - "[1, 6] dimensions (Eigen limit)."); + "Invalid dimensions, input axis should be" + " within [1, 6] dimensions (Eigen limit)."); } }); AddAttr(