From 26cec83901dc443a60aef911c1ad2baf882eb474 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 30 Aug 2017 19:54:14 +0800
Subject: [PATCH 01/13] Add pad op

---
 paddle/operators/CMakeLists.txt               |  1 +
 paddle/operators/pad_op.cc                    | 77 ++++++++++++++++++
 paddle/operators/pad_op.cu                    | 21 +++++
 paddle/operators/pad_op.h                     | 81 +++++++++++++++++++
 paddle/pybind/CMakeLists.txt                  |  3 +-
 paddle/pybind/pybind.cc                       |  1 +
 .../paddle/v2/framework/tests/test_pad_op.py  | 32 ++++++++
 7 files changed, 215 insertions(+), 1 deletion(-)
 create mode 100644 paddle/operators/pad_op.cc
 create mode 100644 paddle/operators/pad_op.cu
 create mode 100644 paddle/operators/pad_op.h
 create mode 100644 python/paddle/v2/framework/tests/test_pad_op.py

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f466dbc79a..1a759133e1 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -72,3 +72,4 @@ op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
 op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
 op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
 op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
+op_library(pad_op SRCS pad_op.cc pad_op.cu)
diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
new file mode 100644
index 0000000000..f96d61669b
--- /dev/null
+++ b/paddle/operators/pad_op.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/pad_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class PadOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto dim0 = ctx.Input<Tensor>("X")->dims();
+    auto dim1 = ctx.Output<Tensor>("Out")->dims();
+    auto paddings = GetAttr<std::vector<std::vector<int>>>("paddings");
+    for (int i = 0; i < dim0.size(); ++i) {
+      dim1[i] = dim0[i] + paddings[i][0] + paddings[i][1];
+    }
+    ctx.Output<Tensor>("Out")->Resize(dim1);
+  }
+};
+
+class MulOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of pad op");
+    AddOutput("Out", "The output of pad op");
+    AddComment(R"DOC(
+Pad Operator.
+)DOC");
+    AddAttr<std::vector<std::vector<int>>>(
+        "paddings", "The padding rules for each dimension");
+    AddAttr<float>("pad_value", "The value to be padded into tensor")
+        .SetDefault(0.0f);
+  }
+};
+
+class PadOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    x_grad->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
+REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(pad_grad,
+                       ops::PadGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/pad_op.cu b/paddle/operators/pad_op.cu
new file mode 100644
index 0000000000..555a7dba23
--- /dev/null
+++ b/paddle/operators/pad_op.cu
@@ -0,0 +1,21 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/pad_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(pad, ops::PadKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(pad_grad,
+                       ops::PadGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h
new file mode 100644
index 0000000000..6a743bd31c
--- /dev/null
+++ b/paddle/operators/pad_op.h
@@ -0,0 +1,81 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include "paddle/operators/math/math_function.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, size_t D, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class PadKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto paddings =
+        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    T pad_value = context.op_.GetAttr<T>("pad_value");
+
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+    Out->mutable_data<T>(context.GetPlace());
+    auto dims = X->dims();
+
+    // Eigen::TensorMap<Eigen::Tensor<T, 2>> X_tensor = EigenTensor<T, 2>::From(*X);
+    // Eigen::TensorMap<Eigen::Tensor<T, 2>>
+    //     Out_tensor = EigenTensor<T, 2>::From(*Out);
+    EigenTensor<T, 2>::ConstType X_tensor =
+        EigenTensor<T, 2>::From(*X);
+    EigenTensor<T, 2>::Type Out_tensor =
+        EigenTensor<T, 2>::From(*Out);
+    Out_tensor = X_tensor.pad(paddings, pad_value);
+  }
+};
+
+template <typename Place, typename T>
+class PadGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    std::vector<std::pair<int, int>> paddings =
+        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    for (int i = 0; i < paddings.size(); ++i) {
+      paddings[0].first = -paddings[0].first;
+      paddings[1].second = -paddings[1].second;
+    }
+    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto dims = dOut->dims();
+
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dX->mutable_data<T>(ctx.GetPlace());
+
+    EigenTensor<T, 2>::Type dX_tensor =
+        EigenTensor<T, 2>::From(*dX);
+    EigenTensor<T, 2>::ConstType dOut_tensor =
+        EigenTensor<T, 2>::From(*dOut);
+    dX_tensor = dOut_tensor.pad(paddings, 0);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index abb9c248ee..17ef1e8291 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -17,5 +17,6 @@ cc_library(paddle_pybind SHARED
     fill_zeros_like_op
     lookup_table_op
     scale_op
-    minus_op)
+    minus_op
+    pad_op)
 endif(WITH_PYTHON)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 8fa8be2cef..0176eb7a88 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -47,6 +47,7 @@ USE_OP(scale);
 USE_OP_ITSELF(identity);
 USE_OP(minus);
 USE_CPU_ONLY_OP(gather);
+USE_OP(pad);
 
 namespace paddle {
 namespace framework {
diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py
new file mode 100644
index 0000000000..89ac7e7e1d
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_pad_op.py
@@ -0,0 +1,32 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestPadOp(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "pad"
+        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+        self.attrs['paddings'] = ((0, 1), (2, 3))
+        self.attrs['pad_value'] = 0
+        self.outputs = {
+            'Out': np.pad(self.inputs['X'],
+                          self.attrs['paddings'],
+                          mode='constant',
+                          constant_value=0)
+        }
+
+
+class PadGradOpTest(GradientChecker):
+    def test_pad(self):
+        op = Operator("pad", paddings=((0, 1), (2, 3)), pad_value=0)
+        inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+
+        self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)
+
+
+if __name__ == '__main__':
+    unittest.main()
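
For reference, the forward semantics this first commit targets match numpy's
constant-mode pad, which is also what the unit test above compares against.
A minimal sketch of the intended behavior in plain numpy (the helper name is
illustrative, not part of the patch):

    import numpy as np

    def pad_forward(x, paddings, pad_value=0.0):
        # paddings holds one (before, after) pair per dimension of x.
        return np.pad(x, paddings, mode='constant', constant_values=pad_value)

    x = np.array([[1., 2.], [3., 4.]])
    out = pad_forward(x, [(0, 1), (2, 3)])
    # out.shape == (3, 7); the original values sit at out[0:2, 2:4].
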
From 3eadb42d3d6e5c78b385104b47d5f564b20e3957 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 6 Sep 2017 10:58:23 +0800
Subject: [PATCH 02/13] Fix eigen error.

---
 paddle/operators/pad_op.cc                    |  12 +-
 paddle/operators/pad_op.h                     | 120 +++++++++++++-----
 .../paddle/v2/framework/tests/test_pad_op.py  |  13 +-
 3 files changed, 101 insertions(+), 44 deletions(-)

diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
index f96d61669b..5dee8d0f5e 100644
--- a/paddle/operators/pad_op.cc
+++ b/paddle/operators/pad_op.cc
@@ -26,18 +26,18 @@ class PadOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto dim0 = ctx.Input<Tensor>("X")->dims();
-    auto dim1 = ctx.Output<Tensor>("Out")->dims();
-    auto paddings = GetAttr<std::vector<std::vector<int>>>("paddings");
+    auto paddings = GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    std::vector<int> dim1(dim0.size());
     for (int i = 0; i < dim0.size(); ++i) {
-      dim1[i] = dim0[i] + paddings[i][0] + paddings[i][1];
+      dim1[i] = dim0[i] + paddings[i].first + paddings[i].second;
     }
-    ctx.Output<Tensor>("Out")->Resize(dim1);
+    ctx.Output<Tensor>("Out")->Resize(paddle::framework::make_ddim(dim1));
   }
 };
 
-class MulOpMaker : public framework::OpProtoAndCheckerMaker {
+class PadOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+  PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of pad op");
     AddOutput("Out", "The output of pad op");
diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h
index 6a743bd31c..9a0a064d75 100644
--- a/paddle/operators/pad_op.h
+++ b/paddle/operators/pad_op.h
@@ -28,52 +28,102 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 
+template <typename Place, typename T, size_t D>
+void PadFunction(const framework::ExecutionContext& context) {
+  auto pads = context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+  Eigen::array<std::pair<int, int>, D> paddings;
+  for (int i = 0; i < pads.size(); ++i) {
+    paddings[i] = pads[i];
+  }
+  T pad_value = context.op_.GetAttr<T>("pad_value");
+
+  auto* X = context.Input<Tensor>("X");
+  auto* Out = context.Output<Tensor>("Out");
+  Out->mutable_data<T>(context.GetPlace());
+  auto dims = X->dims();
+
+  auto X_tensor = EigenTensor<T, D>::From(*X);
+  auto Out_tensor = EigenTensor<T, D>::From(*Out);
+  auto place = context.GetEigenDevice<Place>();
+  Out_tensor.device(place) = X_tensor.pad(paddings, pad_value);
+}
+
 template <typename Place, typename T>
 class PadKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto paddings =
-        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
-    T pad_value = context.op_.GetAttr<T>("pad_value");
-
-    auto* X = context.Input<Tensor>("X");
-    auto* Out = context.Output<Tensor>("Out");
-    Out->mutable_data<T>(context.GetPlace());
-    auto dims = X->dims();
-
-    // Eigen::TensorMap<Eigen::Tensor<T, 2>> X_tensor = EigenTensor<T, 2>::From(*X);
-    // Eigen::TensorMap<Eigen::Tensor<T, 2>>
-    //     Out_tensor = EigenTensor<T, 2>::From(*Out);
-    EigenTensor<T, 2>::ConstType X_tensor =
-        EigenTensor<T, 2>::From(*X);
-    EigenTensor<T, 2>::Type Out_tensor =
-        EigenTensor<T, 2>::From(*Out);
-    Out_tensor = X_tensor.pad(paddings, pad_value);
+    int dim = context.Input<Tensor>("X")->dims().size();
+    switch (dim) {
+      case 1:
+        PadFunction<Place, T, 1>(context);
+        break;
+      case 2:
+        PadFunction<Place, T, 2>(context);
+        break;
+      case 3:
+        PadFunction<Place, T, 3>(context);
+        break;
+      case 4:
+        PadFunction<Place, T, 4>(context);
+        break;
+      case 5:
+        PadFunction<Place, T, 5>(context);
+        break;
+      case 6:
+        PadFunction<Place, T, 6>(context);
+        break;
+      default:
+        LOG(ERROR) << "Only ranks up to 6 supported.";
+    }
   }
 };
 
+template <typename Place, typename T, size_t D>
+void PadGradFunction(const framework::ExecutionContext& context) {
+  auto pads = context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+  Eigen::array<std::pair<int, int>, D> paddings;
+  for (int i = 0; i < pads.size(); ++i) {
+    paddings[0].first = -paddings[0].first;
+    paddings[1].second = -paddings[1].second;
+  }
+  auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+  auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+  dX->mutable_data<T>(context.GetPlace());
+
+  auto dX_tensor = EigenTensor<T, D>::From(*dX);
+  auto dOut_tensor = EigenTensor<T, D>::From(*dOut);
+  auto place = context.GetEigenDevice<Place>();
+  dX_tensor.device(place) = dOut_tensor.pad(paddings, 0);
+}
+
 template <typename Place, typename T>
 class PadGradKernel : public framework::OpKernel {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    std::vector<std::pair<int, int>> paddings =
-        context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
-    for (int i = 0; i < paddings.size(); ++i) {
-      paddings[0].first = -paddings[0].first;
-      paddings[1].second = -paddings[1].second;
+  void Compute(const framework::ExecutionContext& context) const override {
+    size_t dim =
+        context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
+    switch (dim) {
+      case 1:
+        PadGradFunction<Place, T, 1>(context);
+        break;
+      case 2:
+        PadGradFunction<Place, T, 2>(context);
+        break;
+      case 3:
+        PadGradFunction<Place, T, 3>(context);
+        break;
+      case 4:
+        PadGradFunction<Place, T, 4>(context);
+        break;
+      case 5:
+        PadGradFunction<Place, T, 5>(context);
+        break;
+      case 6:
+        PadGradFunction<Place, T, 6>(context);
+        break;
+      default:
+        LOG(ERROR) << "Only ranks up to 6 supported.";
     }
-    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto dims = dOut->dims();
-
-    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    dX->mutable_data<T>(ctx.GetPlace());
-
-    EigenTensor<T, 2>::Type dX_tensor =
-        EigenTensor<T, 2>::From(*dX);
-    EigenTensor<T, 2>::ConstType dOut_tensor =
-        EigenTensor<T, 2>::From(*dOut);
-    dX_tensor = dOut_tensor.pad(paddings, 0);
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py
index 89ac7e7e1d..b862033d8c 100644
--- a/python/paddle/v2/framework/tests/test_pad_op.py
+++ b/python/paddle/v2/framework/tests/test_pad_op.py
@@ -1,5 +1,6 @@
 import unittest
 import numpy as np
+from paddle.v2.framework.op import Operator
 from gradient_checker import GradientChecker, create_op
 from op_test_util import OpTestMeta
 
@@ -10,19 +11,25 @@ class TestPadOp(unittest.TestCase):
     def setUp(self):
         self.type = "pad"
         self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
-        self.attrs['paddings'] = ((0, 1), (2, 3))
+        self.attrs = {}
+        self.attrs['paddings'] = [(0, 1), (2, 3)]
         self.attrs['pad_value'] = 0
         self.outputs = {
             'Out': np.pad(self.inputs['X'],
                           self.attrs['paddings'],
                           mode='constant',
-                          constant_value=0)
+                          constant_values=0)
         }
 
 
 class PadGradOpTest(GradientChecker):
     def test_pad(self):
-        op = Operator("pad", paddings=((0, 1), (2, 3)), pad_value=0)
+        op = Operator(
+            type="pad",
+            X="X",
+            Out="Out",
+            paddings=[(0, 1), (2, 3)],
+            pad_value=0)
         inputs = {'X': np.random.random((16, 16)).astype("float32"), }
 
         self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)

From 9f8e4981384d247e461290d7ceb642486663390d Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 6 Sep 2017 11:59:10 +0800
Subject: [PATCH 03/13] Fix some issues.

---
 paddle/operators/pad_op.cc                      |  3 +++
 paddle/operators/pad_op.h                       | 10 +++++-----
 python/paddle/v2/framework/op.py                |  2 +-
 python/paddle/v2/framework/tests/test_pad_op.py | 15 ++++++++++-----
 4 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc
index 5dee8d0f5e..dac1c56bdd 100644
--- a/paddle/operators/pad_op.cc
+++ b/paddle/operators/pad_op.cc
@@ -27,6 +27,9 @@ class PadOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto dim0 = ctx.Input<Tensor>("X")->dims();
     auto paddings = GetAttr<std::vector<std::pair<int, int>>>("paddings");
+    PADDLE_ENFORCE_EQ(
+        dim0.size(), paddings.size(),
+        "Paddings size should be equal to dimension size of input tensor.");
     std::vector<int> dim1(dim0.size());
     for (int i = 0; i < dim0.size(); ++i) {
       dim1[i] = dim0[i] + paddings[i].first + paddings[i].second;
diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h
index 9a0a064d75..234019394c 100644
--- a/paddle/operators/pad_op.h
+++ b/paddle/operators/pad_op.h
@@ -14,8 +14,6 @@
 
 #pragma once
 
-#include "paddle/operators/math/math_function.h"
-
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 
@@ -30,12 +28,13 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 template <typename Place, typename T, size_t D>
 void PadFunction(const framework::ExecutionContext& context) {
-  auto pads = context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+  auto pads =
+      context.op().GetAttr<std::vector<std::pair<int, int>>>("paddings");
   Eigen::array<std::pair<int, int>, D> paddings;
   for (int i = 0; i < pads.size(); ++i) {
     paddings[i] = pads[i];
   }
-  T pad_value = context.op_.GetAttr<T>("pad_value");
+  T pad_value = context.op().GetAttr<T>("pad_value");
@@ -80,7 +79,8 @@ class PadKernel : public framework::OpKernel {
 template <typename Place, typename T, size_t D>
 void PadGradFunction(const framework::ExecutionContext& context) {
-  auto pads = context.op_.GetAttr<std::vector<std::pair<int, int>>>("paddings");
+  auto pads =
+      context.op().GetAttr<std::vector<std::pair<int, int>>>("paddings");
   Eigen::array<std::pair<int, int>, D> paddings;
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 0349407a85..359ccec814 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -96,7 +96,7 @@ class OpDescCreationMethod(object):
                 new_attr.strings.extend(user_defined_attr)
             elif attr.type == framework_pb2.INT_PAIRS:
                 for p in user_defined_attr:
-                    pair = new_attr.pairs.add()
+                    pair = new_attr.int_pairs.add()
                     pair.first = p[0]
                     pair.second = p[1]
             else:
diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py
index b862033d8c..10aeaa752f 100644
--- a/python/paddle/v2/framework/tests/test_pad_op.py
+++ b/python/paddle/v2/framework/tests/test_pad_op.py
@@ -22,17 +22,22 @@ class TestPadOp(unittest.TestCase):
         }
 
 
-class PadGradOpTest(GradientChecker):
-    def test_pad(self):
-        op = Operator(
+class TestPadGradOp(GradientChecker):
+    def setUp(self):
+        self.op = Operator(
             type="pad",
             X="X",
             Out="Out",
             paddings=[(0, 1), (2, 3)],
             pad_value=0)
-        inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+        self.inputs = {'X': np.random.random((16, 16)).astype("float32"), }
+
+    def test_normal(self):
+        self.check_grad(
+            self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5)
 
-        self.check_grad(op, inputs, set(["X"]), "Out", max_relative_error=0.5)
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
 
 
 if __name__ == '__main__':

From 7c30251d165ee9b3b9fd4fbd2440824ebcfbb5d7 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 6 Sep 2017 13:10:52 +0800
Subject: [PATCH 04/13] Fix padding attribute error.

---
 paddle/operators/pad_op.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h
index 234019394c..ed547d0a7f 100644
--- a/paddle/operators/pad_op.h
+++ b/paddle/operators/pad_op.h
@@ -83,8 +83,8 @@ void PadGradFunction(const framework::ExecutionContext& context) {
       context.op().GetAttr<std::vector<std::pair<int, int>>>("paddings");
   Eigen::array<std::pair<int, int>, D> paddings;
   for (int i = 0; i < pads.size(); ++i) {
-    paddings[0].first = -paddings[0].first;
-    paddings[1].second = -paddings[1].second;
+    paddings[i].first = -pads[i].first;
+    paddings[i].second = -pads[i].second;
   }
   auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
   auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
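
With this fix the backward pass applies the negated forward paddings, which is
equivalent to slicing the output gradient back down to the input shape:
constant-mode padding copies X into an interior block of Out, so dL/dX is
exactly the matching block of dL/dOut and the padded cells contribute nothing.
A small numpy sketch of that identity (the helper name is illustrative, not
part of the patch):

    import numpy as np

    def pad_grad(d_out, paddings):
        # Negating each (before, after) pair selects the interior block
        # that the forward pass copied X into.
        slices = tuple(
            slice(before, d_out.shape[i] - after)
            for i, (before, after) in enumerate(paddings))
        return d_out[slices]

    d_out = np.ones((3, 5))
    assert pad_grad(d_out, [(0, 1), (1, 2)]).shape == (2, 2)
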
") .SetDefault(0.0f); } }; diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h index ed547d0a7f..dcf957b47e 100644 --- a/paddle/operators/pad_op.h +++ b/paddle/operators/pad_op.h @@ -28,23 +28,23 @@ using EigenTensor = framework::EigenTensor; template void PadFunction(const framework::ExecutionContext& context) { - auto pads = - context.op().GetAttr>>("paddings"); + auto pads = context.GetAttr>("paddings"); Eigen::array, D> paddings; - for (int i = 0; i < pads.size(); ++i) { - paddings[i] = pads[i]; + for (int i = 0; i < paddings.size(); ++i) { + paddings[i].first = pads[i * 2]; + paddings[i].second = pads[i * 2 + 1]; } - T pad_value = context.op().GetAttr("pad_value"); + T pad_value = context.GetAttr("pad_value"); - auto* X = context.Input("X"); - auto* Out = context.Output("Out"); - Out->mutable_data(context.GetPlace()); - auto dims = X->dims(); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + auto dims = x->dims(); - auto X_tensor = EigenTensor::From(*X); - auto Out_tensor = EigenTensor::From(*Out); + auto x_tensor = EigenTensor::From(*x); + auto out_tensor = EigenTensor::From(*out); auto place = context.GetEigenDevice(); - Out_tensor.device(place) = X_tensor.pad(paddings, pad_value); + out_tensor.device(place) = x_tensor.pad(paddings, pad_value); } template @@ -72,28 +72,27 @@ class PadKernel : public framework::OpKernel { PadFunction(context); break; default: - LOG(ERROR) << "Only ranks up to 6 supported."; + PADDLE_THROW("Only ranks up to 6 supported."); } } }; template void PadGradFunction(const framework::ExecutionContext& context) { - auto pads = - context.op().GetAttr>>("paddings"); + auto pads = context.GetAttr>("paddings"); Eigen::array, D> paddings; - for (int i = 0; i < pads.size(); ++i) { - paddings[i].first = -pads[i].first; - paddings[i].second = -pads[i].second; + for (int i = 0; i < paddings.size(); ++i) { + paddings[i].first = -pads[i * 2]; + paddings[i].second = -pads[i * 2 + 1]; } - auto* dOut = context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); - dX->mutable_data(context.GetPlace()); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); + d_x->mutable_data(context.GetPlace()); - auto dX_tensor = EigenTensor::From(*dX); - auto dOut_tensor = EigenTensor::From(*dOut); + auto d_x_tensor = EigenTensor::From(*d_x); + auto d_out_tensor = EigenTensor::From(*d_out); auto place = context.GetEigenDevice(); - dX_tensor.device(place) = dOut_tensor.pad(paddings, 0); + d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); } template @@ -122,7 +121,7 @@ class PadGradKernel : public framework::OpKernel { PadGradFunction(context); break; default: - LOG(ERROR) << "Only ranks up to 6 supported."; + PADDLE_THROW("Only ranks up to 6 supported."); } } }; diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py index 10aeaa752f..56b9c88f7d 100644 --- a/python/paddle/v2/framework/tests/test_pad_op.py +++ b/python/paddle/v2/framework/tests/test_pad_op.py @@ -9,36 +9,89 @@ class TestPadOp(unittest.TestCase): __metaclass__ = OpTestMeta def setUp(self): + self.initTestCase() self.type = "pad" - self.inputs = {'X': np.random.random((16, 16)).astype("float32"), } + self.inputs = {'X': np.random.random(self.shape).astype("float32"), } self.attrs = {} - self.attrs['paddings'] = [(0, 1), (2, 3)] - self.attrs['pad_value'] = 0 + self.attrs['paddings'] 
= np.array(self.paddings).flatten() + self.attrs['pad_value'] = self.pad_value self.outputs = { 'Out': np.pad(self.inputs['X'], - self.attrs['paddings'], + self.paddings, mode='constant', - constant_values=0) + constant_values=self.pad_value) } + def initTestCase(self): + self.shape = (16, 16) + self.paddings = [(0, 1), (2, 3)] + self.pad_value = 0 + + +class TestCase1(TestPadOp): + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1)] + self.pad_value = 0.5 + + +class TestCase2(TestPadOp): + def initTestCase(self): + self.shape = (2, 2, 2) + self.paddings = [(0, 0), (0, 0), (1, 2)] + self.pad_value = 1 + + +class TestCase3(TestPadOp): + def initTestCase(self): + self.shape = (8) + self.paddings = [(0, 1)] + self.pad_value = 0.9 + class TestPadGradOp(GradientChecker): def setUp(self): + self.initTestCase() self.op = Operator( type="pad", X="X", Out="Out", - paddings=[(0, 1), (2, 3)], - pad_value=0) - self.inputs = {'X': np.random.random((16, 16)).astype("float32"), } + paddings=np.array(self.paddings).flatten(), + pad_value=self.pad_value) + self.inputs = {'X': np.random.random(self.shape).astype("float32"), } + + def initTestCase(self): + self.shape = (16, 16) + self.paddings = [(0, 1), (2, 3)] + self.pad_value = 0 def test_normal(self): - self.check_grad( - self.op, self.inputs, set(["X"]), "Out", max_relative_error=0.5) + self.check_grad(self.op, self.inputs, set(["X"]), "Out") def test_cpu_gpu_compare(self): self.compare_grad(self.op, self.inputs) +class TestiGradCase1(TestPadOp): + def initTestCase(self): + self.shape = (2, 3, 4, 4) + self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1)] + self.pad_value = 0.5 + + +class TestGradCase2(TestPadOp): + def initTestCase(self): + self.shape = (2, 2, 2) + self.paddings = [(0, 0), (0, 0), (1, 2)] + self.pad_value = 1 + + +class TestGradCase3(TestPadOp): + def initTestCase(self): + self.shape = (8) + self.paddings = [(0, 1)] + self.pad_value = 0.9 + + if __name__ == '__main__': unittest.main() From d960cbdcf3f162c0da17fd04c8bc8eb770c9965b Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 8 Sep 2017 16:48:39 +0800 Subject: [PATCH 06/13] Fix comment --- paddle/operators/pad_op.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 94a6d20583..6ea2a25f0b 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -27,9 +27,9 @@ class PadOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto dim0 = ctx.Input("X")->dims(); auto paddings = GetAttr>("paddings"); - PADDLE_ENFORCE_EQ( - dim0.size(), (int)(paddings.size() / 2), - "Paddings size should be equal to dimension size of input tensor."); + PADDLE_ENFORCE_EQ(dim0.size(), (int)(paddings.size() / 2), + "Size of paddings should be equal to 2 * dimension size " + "of input tensor."); std::vector dim1(dim0.size()); for (int i = 0; i < dim0.size(); ++i) { dim1[i] = dim0[i] + paddings[i * 2] + paddings[i * 2 + 1]; @@ -54,7 +54,7 @@ X = [[1, 2], and -paddings = [(0,1),(1,2)] +paddings = [0, 1, 1, 2] and @@ -68,11 +68,11 @@ Out = [[0, 1, 2, 0, 0] )DOC"); AddAttr>( "paddings", - "A pair list to describes padding rules for each dimension." - " For 2-D image tensor, paddings=[(0, 1), (2, 3)] means" + "A list to describes padding rules for each dimension." 
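
From this commit on, "paddings" travels as a flat int list
[before_0, after_0, before_1, after_1, ...] rather than a list of pairs;
the kernel reads entries 2*i and 2*i+1 for dimension i, and the Python tests
flatten their pair lists before attaching them. A short numpy sketch of the
round trip (illustrative only, not part of the patch):

    import numpy as np

    pairs = [(0, 1), (2, 3)]
    flat = np.array(pairs).flatten()              # [0, 1, 2, 3], what the op receives
    restored = list(zip(flat[0::2], flat[1::2]))  # back to [(0, 1), (2, 3)]

    x = np.random.random((16, 16)).astype("float32")
    out = np.pad(x, restored, mode='constant', constant_values=0)
    assert out.shape == (16 + 0 + 1, 16 + 2 + 3)
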
+ " For 2-D image tensor, paddings=[0, 1, 2, 3] means" " padding 0 row to top, 1 row to bottom, 2 columns to left" - " and 3 columns to right.Paddings size should be equal to" - " dimension size of input tensor."); + " and 3 columns to right.Size of paddings should be equal to" + " 2 * dimension size of input tensor."); AddAttr("pad_value", "(float) default to 0; " "The value to be padded into tensor. ") From c7b347887dd6285dcb171499c17d705d424924ad Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 11 Sep 2017 11:46:04 +0800 Subject: [PATCH 07/13] Fix variable names and comments --- paddle/operators/pad_op.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 6ea2a25f0b..894fe2cecf 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -25,16 +25,16 @@ class PadOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dim0 = ctx.Input("X")->dims(); - auto paddings = GetAttr>("paddings"); - PADDLE_ENFORCE_EQ(dim0.size(), (int)(paddings.size() / 2), + auto x_dim = ctx.Input("X")->dims(); + auto paddings = Attr>("paddings"); + PADDLE_ENFORCE_EQ(x_dim.size() * 2, int(paddings.size()), "Size of paddings should be equal to 2 * dimension size " "of input tensor."); - std::vector dim1(dim0.size()); - for (int i = 0; i < dim0.size(); ++i) { - dim1[i] = dim0[i] + paddings[i * 2] + paddings[i * 2 + 1]; + std::vector out_dims(x_dim.size()); + for (int i = 0; i < x_dim.size(); ++i) { + out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; } - ctx.Output("Out")->Resize(paddle::framework::make_ddim(dim1)); + ctx.Output("Out")->Resize(framework::make_ddim(out_dims)); } }; @@ -42,8 +42,12 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker { public: PadOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input of pad op."); - AddOutput("Out", "The output of pad op."); + AddInput("X", + "The input of pad op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddOutput("Out", + "The output of pad op." + "A tensor with the same shape as X."); AddComment(R"DOC( Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example: @@ -75,7 +79,7 @@ Out = [[0, 1, 2, 0, 0] " 2 * dimension size of input tensor."); AddAttr("pad_value", "(float) default to 0; " - "The value to be padded into tensor. 
") + "The value to fill padded areas.") .SetDefault(0.0f); } }; From f31217fc2e535d0d1079a02895214c2c2f434809 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 11 Sep 2017 14:50:54 +0800 Subject: [PATCH 08/13] Fix issues --- paddle/operators/pad_op.cc | 5 +++-- paddle/operators/pad_op.h | 21 +++++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 894fe2cecf..ef678cf3d3 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -27,10 +27,10 @@ class PadOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { auto x_dim = ctx.Input("X")->dims(); auto paddings = Attr>("paddings"); - PADDLE_ENFORCE_EQ(x_dim.size() * 2, int(paddings.size()), + PADDLE_ENFORCE_EQ(x_dim.size() * 2, int64_t(paddings.size()), "Size of paddings should be equal to 2 * dimension size " "of input tensor."); - std::vector out_dims(x_dim.size()); + std::vector out_dims(x_dim.size()); for (int i = 0; i < x_dim.size(); ++i) { out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1]; } @@ -95,6 +95,7 @@ class PadOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); + PADDLE_ENFORCE_NOT_NULL(x_grad, "Output(X@GRAD) should not be null"); x_grad->Resize(x_dims); } diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h index dcf957b47e..53451f925a 100644 --- a/paddle/operators/pad_op.h +++ b/paddle/operators/pad_op.h @@ -28,18 +28,17 @@ using EigenTensor = framework::EigenTensor; template void PadFunction(const framework::ExecutionContext& context) { - auto pads = context.GetAttr>("paddings"); + auto pads = context.Attr>("paddings"); Eigen::array, D> paddings; for (int i = 0; i < paddings.size(); ++i) { paddings[i].first = pads[i * 2]; paddings[i].second = pads[i * 2 + 1]; } - T pad_value = context.GetAttr("pad_value"); + T pad_value = context.Attr("pad_value"); auto* x = context.Input("X"); auto* out = context.Output("Out"); out->mutable_data(context.GetPlace()); - auto dims = x->dims(); auto x_tensor = EigenTensor::From(*x); auto out_tensor = EigenTensor::From(*out); @@ -51,8 +50,8 @@ template class PadKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int dim = context.Input("X")->dims().size(); - switch (dim) { + int rank = context.Input("X")->dims().size(); + switch (rank) { case 1: PadFunction(context); break; @@ -72,14 +71,15 @@ class PadKernel : public framework::OpKernel { PadFunction(context); break; default: - PADDLE_THROW("Only ranks up to 6 supported."); + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); } } }; template void PadGradFunction(const framework::ExecutionContext& context) { - auto pads = context.GetAttr>("paddings"); + auto pads = context.Attr>("paddings"); Eigen::array, D> paddings; for (int i = 0; i < paddings.size(); ++i) { paddings[i].first = -pads[i * 2]; @@ -99,9 +99,9 @@ template class PadGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - size_t dim = + size_t rank = context.Input(framework::GradVarName("Out"))->dims().size(); - switch (dim) { + switch (rank) { case 1: PadGradFunction(context); break; @@ -121,7 +121,8 @@ class PadGradKernel : public framework::OpKernel { PadGradFunction(context); break; 
default: - PADDLE_THROW("Only ranks up to 6 supported."); + PADDLE_THROW( + "PadOp only support tensors with no more than 6 dimensions."); } } }; From 9c929a495980643672f66c882e76ca67e761954f Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 11 Sep 2017 15:19:19 +0800 Subject: [PATCH 09/13] Fix warning log --- paddle/operators/pad_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h index 53451f925a..ca8832f26a 100644 --- a/paddle/operators/pad_op.h +++ b/paddle/operators/pad_op.h @@ -30,7 +30,7 @@ template void PadFunction(const framework::ExecutionContext& context) { auto pads = context.Attr>("paddings"); Eigen::array, D> paddings; - for (int i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < paddings.size(); ++i) { paddings[i].first = pads[i * 2]; paddings[i].second = pads[i * 2 + 1]; } @@ -81,7 +81,7 @@ template void PadGradFunction(const framework::ExecutionContext& context) { auto pads = context.Attr>("paddings"); Eigen::array, D> paddings; - for (int i = 0; i < paddings.size(); ++i) { + for (size_t i = 0; i < paddings.size(); ++i) { paddings[i].first = -pads[i * 2]; paddings[i].second = -pads[i * 2 + 1]; } From 355e35fecd2866a1894c304647f6875cf15f7571 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 12 Sep 2017 10:12:33 +0800 Subject: [PATCH 10/13] fix paddle enforce check --- paddle/operators/pad_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index ef678cf3d3..449463c830 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -95,7 +95,6 @@ class PadOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); - PADDLE_ENFORCE_NOT_NULL(x_grad, "Output(X@GRAD) should not be null"); x_grad->Resize(x_dims); } From 012453e28c4f27fd247d922671325011df1a6bb8 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 12 Sep 2017 10:43:19 +0800 Subject: [PATCH 11/13] fix NoInGrad bug --- paddle/operators/pad_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 449463c830..99f605c651 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -47,7 +47,8 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker { "The input should be a k-D tensor(k > 0 and k < 7)"); AddOutput("Out", "The output of pad op." - "A tensor with the same shape as X."); + "A tensor with the same shape as X.") + .NotInGradient(); AddComment(R"DOC( Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). 
As an example: From 236a84c5050d419285cb7fbcc9c8f5bf923058ab Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 13 Sep 2017 11:09:01 +0800 Subject: [PATCH 12/13] Fix nullptr check --- paddle/operators/pad_op.cc | 5 +++-- paddle/operators/pad_op.h | 13 +++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 99f605c651..7e78b6ec13 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -96,8 +96,9 @@ class PadOpGrad : public framework::OperatorWithKernel { "Input(Out@GRAD) should not be null"); auto x_dims = ctx.Input("X")->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); - - x_grad->Resize(x_dims); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } } }; diff --git a/paddle/operators/pad_op.h b/paddle/operators/pad_op.h index ca8832f26a..2cc3b945ae 100644 --- a/paddle/operators/pad_op.h +++ b/paddle/operators/pad_op.h @@ -87,12 +87,13 @@ void PadGradFunction(const framework::ExecutionContext& context) { } auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = context.Output(framework::GradVarName("X")); - d_x->mutable_data(context.GetPlace()); - - auto d_x_tensor = EigenTensor::From(*d_x); - auto d_out_tensor = EigenTensor::From(*d_out); - auto place = context.GetEigenDevice(); - d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); + if (d_x != nullptr) { + d_x->mutable_data(context.GetPlace()); + auto d_x_tensor = EigenTensor::From(*d_x); + auto d_out_tensor = EigenTensor::From(*d_out); + auto place = context.GetEigenDevice(); + d_x_tensor.device(place) = d_out_tensor.pad(paddings, 0); + } } template From 23381dd16a6800a4e73a4bc36c0b1013a30e520d Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 13 Sep 2017 12:16:52 +0800 Subject: [PATCH 13/13] Update pad op unitest --- .../paddle/v2/framework/tests/test_pad_op.py | 60 +++---------------- 1 file changed, 9 insertions(+), 51 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_pad_op.py b/python/paddle/v2/framework/tests/test_pad_op.py index 56b9c88f7d..456b765e33 100644 --- a/python/paddle/v2/framework/tests/test_pad_op.py +++ b/python/paddle/v2/framework/tests/test_pad_op.py @@ -1,16 +1,12 @@ import unittest import numpy as np -from paddle.v2.framework.op import Operator -from gradient_checker import GradientChecker, create_op -from op_test_util import OpTestMeta +from op_test import OpTest -class TestPadOp(unittest.TestCase): - __metaclass__ = OpTestMeta - +class TestPadOp(OpTest): def setUp(self): self.initTestCase() - self.type = "pad" + self.op_type = "pad" self.inputs = {'X': np.random.random(self.shape).astype("float32"), } self.attrs = {} self.attrs['paddings'] = np.array(self.paddings).flatten() @@ -22,6 +18,12 @@ class TestPadOp(unittest.TestCase): constant_values=self.pad_value) } + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + def initTestCase(self): self.shape = (16, 16) self.paddings = [(0, 1), (2, 3)] @@ -49,49 +51,5 @@ class TestCase3(TestPadOp): self.pad_value = 0.9 -class TestPadGradOp(GradientChecker): - def setUp(self): - self.initTestCase() - self.op = Operator( - type="pad", - X="X", - Out="Out", - paddings=np.array(self.paddings).flatten(), - pad_value=self.pad_value) - self.inputs = {'X': np.random.random(self.shape).astype("float32"), } - - def initTestCase(self): - self.shape = (16, 16) - self.paddings = [(0, 1), (2, 3)] - self.pad_value = 0 - - def test_normal(self): - 
self.check_grad(self.op, self.inputs, set(["X"]), "Out") - - def test_cpu_gpu_compare(self): - self.compare_grad(self.op, self.inputs) - - -class TestiGradCase1(TestPadOp): - def initTestCase(self): - self.shape = (2, 3, 4, 4) - self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1)] - self.pad_value = 0.5 - - -class TestGradCase2(TestPadOp): - def initTestCase(self): - self.shape = (2, 2, 2) - self.paddings = [(0, 0), (0, 0), (1, 2)] - self.pad_value = 1 - - -class TestGradCase3(TestPadOp): - def initTestCase(self): - self.shape = (8) - self.paddings = [(0, 1)] - self.pad_value = 0.9 - - if __name__ == '__main__': unittest.main()
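
Taken together, the shape rule the final operator enforces is
out_dims[i] = x_dims[i] + paddings[2*i] + paddings[2*i+1], with exactly two
padding entries per input dimension. A closing Python sketch of that rule
(a hypothetical helper mirroring PadOp::InferShape, for illustration only):

    def infer_pad_shape(x_dims, paddings):
        # paddings is the flat [before_0, after_0, before_1, after_1, ...] list.
        assert len(paddings) == 2 * len(x_dims)
        return [d + paddings[2 * i] + paddings[2 * i + 1]
                for i, d in enumerate(x_dims)]

    assert infer_pad_shape([16, 16], [0, 1, 2, 3]) == [17, 21]
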