diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index c53367a20f..251946f6fd 100755
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -110,6 +110,8 @@ const char kNameSigmoidCrossEntropyWithLogits[] = "SigmoidCrossEntropyWithLogits
 const char kNameSigmoidCrossEntropyWithLogitsGrad[] = "SigmoidCrossEntropyWithLogitsGrad";
 const char kNameScatterNdD[] = "ScatterNd";
 const char kNamePadD[] = "Pad";
+const char kNameMirrorPad[] = "MirrorPad";
+const char kNameMirrorPadGrad[] = "MirrorPadGrad";
 const char kNameGatherNd[] = "GatherNd";
 const char kNameArgmax[] = "Argmax";
 const char kNameArgmin[] = "Argmin";
@@ -256,6 +258,8 @@ std::unordered_map<std::string, OpAdapterDescPtr> &DfGraphConvertor::get_adpt_ma
     {string(kNameSigmoidCrossEntropyWithLogitsGrad), ADPT_DESC(SigmoidCrossEntropyWithLogitsGrad)},
     {string(kNameScatterNdD), ADPT_DESC(ScatterNdD)},
     {string(kNamePadD), ADPT_DESC(PadD)},
+    {string(kNameMirrorPad), ADPT_DESC(MirrorPad)},
+    {string(kNameMirrorPadGrad), ADPT_DESC(MirrorPadGrad)},
     {string(kNameGatherNd), ADPT_DESC(GatherNd)},
     {string(kNameArgmax), ADPT_DESC(ArgMaxD)},
     {string(kNameArgmin), ADPT_DESC(ArgMinD)},
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 419805c37f..7a7a696e2d 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -596,6 +596,16 @@ INPUT_MAP(PadD) = {{1, INPUT_DESC(x)}};
 ATTR_MAP(PadD) = {{"paddings", ATTR_DESC(paddings, AnyTraits<std::vector<std::vector<int64_t>>>())}};
 OUTPUT_MAP(PadD) = {{0, OUTPUT_DESC(y)}};
 
+// MirrorPad
+INPUT_MAP(MirrorPad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
+ATTR_MAP(MirrorPad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
+OUTPUT_MAP(MirrorPad) = {{0, OUTPUT_DESC(y)}};
+
+// MirrorPadGrad
+INPUT_MAP(MirrorPadGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(paddings)}};
+ATTR_MAP(MirrorPadGrad) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
+OUTPUT_MAP(MirrorPadGrad) = {{0, OUTPUT_DESC(y)}};
+
 // GatherNd
 INPUT_MAP(GatherNd) = {{1, INPUT_DESC(x1)}, {2, INPUT_DESC(x2)}};
 ATTR_MAP(GatherNd) = EMPTY_ATTR_MAP;
diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h
index e4d4101127..8f6dda9430 100755
--- a/mindspore/ccsrc/transform/op_declare.h
+++ b/mindspore/ccsrc/transform/op_declare.h
@@ -155,6 +155,10 @@ DECLARE_OP_USE_INPUT_ATTR(ScatterNdD)
 DECLARE_OP_USE_OUTPUT(ScatterNdD)
 DECLARE_OP_ADAPTER(PadD)
 DECLARE_OP_USE_OUTPUT(PadD)
+DECLARE_OP_ADAPTER(MirrorPad)
+DECLARE_OP_USE_OUTPUT(MirrorPad)
+DECLARE_OP_ADAPTER(MirrorPadGrad)
+DECLARE_OP_USE_OUTPUT(MirrorPadGrad)
 DECLARE_OP_ADAPTER(BoundingBoxEncode)
 DECLARE_OP_USE_OUTPUT(BoundingBoxEncode)
 DECLARE_OP_ADAPTER(BoundingBoxDecode)
diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py
index aed6cb7776..f51eff2b31 100644
--- a/mindspore/nn/layer/__init__.py
+++ b/mindspore/nn/layer/__init__.py
@@ -22,7 +22,7 @@ from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm
 from .container import SequentialCell, CellList
 from .conv import Conv2d, Conv2dTranspose
 from .lstm import LSTM
-from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients
+from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradients, Pad
 from .embedding import Embedding
 from .pooling import AvgPool2d, MaxPool2d
 
@@ -34,5 +34,5 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
            'LSTM', 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'ImageGradients',
            'Embedding',
-           'AvgPool2d', 'MaxPool2d',
+           'AvgPool2d', 'MaxPool2d', 'Pad',
            ]
diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py
index de49685dac..5b36755d16 100644
--- a/mindspore/nn/layer/basic.py
+++ b/mindspore/nn/layer/basic.py
@@ -415,3 +415,72 @@ class ImageGradients(Cell):
         dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0)
         dx = P.Concat(3)((dx, dx_last))
         return dy, dx
+
+
+class Pad(Cell):
+    """
+    Pads the input tensor according to the paddings and mode.
+
+    Args:
+        paddings (tuple): The shape of parameter `paddings` is (N, 2). N is the rank of the input data. All
+            elements of paddings are int type. For the `D` th dimension of the input, paddings[D, 0] indicates
+            how many sizes to be extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1]
+            indicates how many sizes to be extended behind the `D` th dimension of the input tensor.
+        mode (str): Specifies the padding mode. The optional values are "CONSTANT", "REFLECT" and "SYMMETRIC".
+            Default: "CONSTANT".
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, the tensor after padding.
+
+        - If `mode` is "CONSTANT", the padded regions are filled with 0, regardless of the values of `input_x`.
+          If `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the output is
+          [[0,0,0,0,0,0,0],[0,0,1,2,3,0,0],[0,0,4,5,6,0,0],[0,0,7,8,9,0,0],[0,0,0,0,0,0,0]].
+        - If `mode` is "REFLECT", the padded regions are filled by mirroring the tensor across the axis of
+          symmetry, excluding the symmetry axis itself. If `input_x` is [[1,2,3],[4,5,6],[7,8,9]] and
+          `paddings` is [[1,1],[2,2]], then the output is
+          [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
+        - If `mode` is "SYMMETRIC", the filling method is similar to "REFLECT": values are also copied
+          across the symmetry axis, except that the symmetry axis itself is included. If `input_x`
+          is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the output is
+          [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].
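+
+        These three modes appear to line up with NumPy's np.pad modes "constant" (zero fill),
+        "reflect" and "symmetric", as the worked examples above suggest, so np.pad can serve
+        as a quick cross-check of expected values:
+
+        >>> import numpy as np
+        >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        >>> np.pad(a, ((1, 1), (2, 2)), mode='symmetric')[0].tolist()
+        [2, 1, 1, 2, 3, 3, 2]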
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.pad = nn.Pad(paddings=((1,1),(2,2)), mode="CONSTANT")
+        >>>     def construct(self, x):
+        >>>         return self.pad(x)
+        >>> x = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> pad = Net()
+        >>> ms_output = pad(Tensor(x))
+    """
+
+    def __init__(self, paddings, mode="CONSTANT"):
+        super(Pad, self).__init__()
+        self.mode = mode
+        self.paddings = paddings
+        validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"])
+        if not isinstance(paddings, tuple):
+            raise TypeError('Paddings must be tuple type.')
+        for item in paddings:
+            if len(item) != 2:
+                raise ValueError('The shape of paddings must be (n, 2).')
+        if mode == "CONSTANT":
+            self.pad = P.Pad(self.paddings)
+        else:
+            self.paddings = Tensor(np.array(self.paddings))
+            self.pad = P.MirrorPad(mode=mode)
+
+    def construct(self, x):
+        if self.mode == "CONSTANT":
+            x = self.pad(x)
+        else:
+            x = self.pad(x, self.paddings)
+        return x
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 1b18d9f248..149dd6caec 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -470,6 +470,17 @@ def get_bprop_pad(self):
     return bprop
 
 
+@bprop_getters.register(P.MirrorPad)
+def get_bprop_mirror_pad(self):
+    """Grad definition for `MirrorPad` operation."""
+    mirror_pad_grad = G.MirrorPadGrad(self.mode)
+
+    def bprop(x, paddings, out, dout):
+        dx = mirror_pad_grad(dout, paddings, x)
+        return (dx, zeros_like(paddings))
+    return bprop
+
+
 @bprop_getters.register(P.ROIAlign)
 def get_bprop_roi_align(self):
     """Grad definition for `ROIAlign` operation."""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 8bfca77b38..40cbfc3381 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                      LogSoftmax,
                      MaxPool, AvgPool, Conv2DBackpropInput,
-                     MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
+                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid,
                      ResizeBilinear, Sigmoid,
                      SigmoidCrossEntropyWithLogits,
                      SmoothL1Loss, Softmax,
@@ -180,6 +180,7 @@ __all__ = [
     'ScatterNd',
     'ResizeNearestNeighbor',
     'Pad',
+    'MirrorPad',
     'GatherNd',
     'ScatterNdUpdate',
     'Floor',
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index f0a9a2f658..d468fa7b19 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -947,6 +947,24 @@ class TanhGrad(PrimitiveWithInfer):
         return out
 
 
+class MirrorPadGrad(PrimitiveWithInfer):
+    """Gradients of MirrorPad operation."""
+
+    @prim_attr_register
+    def __init__(self, mode="REFLECT"):
+        """Init MirrorPadGrad"""
+        validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'])
+        self.mode = mode
+
+    def __infer__(self, dout, paddings, x):
+        validator.check_subclass("dout", dout['dtype'], mstype.tensor)
+        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor)
+        validator.check_subclass("input_x", x['dtype'], mstype.tensor)
+        return {'shape': x['shape'],
+                'dtype': dout['dtype'],
+                'value': None}
+
+
 class RefToEmbed(Primitive):
     r"""
     Make a key from Ref.
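The MirrorPadGrad primitive registered above only infers shape and dtype; the backward semantics it stands for are a scatter-add, where every element of the padded gradient is accumulated back at the position it was mirrored from. A minimal NumPy sketch of the REFLECT case (reflect_indices and mirror_pad_grad_2d are hypothetical helper names, and it assumes each pad width is at most dim - 1, the usual REFLECT constraint):

    import numpy as np

    def reflect_indices(n, before, after):
        # Source index in the unpadded axis for every padded position
        # (REFLECT mode: the edge element itself is not repeated).
        idx = np.arange(-before, n + after)
        period = 2 * (n - 1) if n > 1 else 1
        idx = np.abs(idx) % period
        return np.where(idx >= n, period - idx, idx)

    def mirror_pad_grad_2d(dout, x_shape, paddings):
        # Accumulate every padded gradient element at its source position.
        rows = reflect_indices(x_shape[0], *paddings[0])
        cols = reflect_indices(x_shape[1], *paddings[1])
        dx = np.zeros(x_shape, dtype=dout.dtype)
        np.add.at(dx, (rows[:, None], cols[None, :]), dout)
        return dx

Interior elements that are mirrored into several padded positions receive the sum of all matching dout entries, which is the accumulation the registered bprop delegates to the MirrorPadGrad kernel.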
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 91f6d7ec01..1e3a4349ae 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2092,6 +2092,7 @@ class Pad(PrimitiveWithInfer):
         for item in paddings:
             if len(item) != 2:
                 raise ValueError('The shape of paddings must be (n, 2).')
+        self.paddings = paddings
 
     def infer_shape(self, x):
         paddings = np.array(self.paddings)
@@ -2104,9 +2105,78 @@ class Pad(PrimitiveWithInfer):
         return y_shape
 
     def infer_dtype(self, x):
+        validator.check_subclass("input_x", x, mstype.tensor)
         return x
 
 
+class MirrorPad(PrimitiveWithInfer):
+    """
+    Pads the input tensor according to the paddings and mode.
+
+    Args:
+        mode (str): Specifies the padding mode. The optional values are "REFLECT" and "SYMMETRIC".
+            Default: "REFLECT".
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+        - **paddings** (Tensor) - The paddings tensor. The value of `paddings` is a matrix (list),
+          and its shape is (N, 2). N is the rank of the input data. All elements of paddings
+          are int type. For the `D` th dimension of the input, paddings[D, 0] indicates how many
+          sizes to be extended ahead of the `D` th dimension of the input tensor, and paddings[D, 1]
+          indicates how many sizes to be extended behind the `D` th dimension of the input tensor.
+
+    Outputs:
+        Tensor, the tensor after padding.
+
+        - If `mode` is "REFLECT", the padded regions are filled by mirroring the tensor across the
+          axis of symmetry, excluding the symmetry axis itself. If `input_x` is [[1,2,3],[4,5,6],[7,8,9]]
+          and `paddings` is [[1,1],[2,2]], then the output is
+          [[6,5,4,5,6,5,4],[3,2,1,2,3,2,1],[6,5,4,5,6,5,4],[9,8,7,8,9,8,7],[6,5,4,5,6,5,4]].
+        - If `mode` is "SYMMETRIC", the filling method is similar to "REFLECT": values are also copied
+          across the symmetry axis, except that the symmetry axis itself is included. If `input_x`
+          is [[1,2,3],[4,5,6],[7,8,9]] and `paddings` is [[1,1],[2,2]], then the output is
+          [[2,1,1,2,3,3,2],[2,1,1,2,3,3,2],[5,4,4,5,6,6,5],[8,7,7,8,9,9,8],[8,7,7,8,9,9,8]].
+
+    Examples:
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.pad = P.MirrorPad(mode="REFLECT")
+        >>>     def construct(self, x, paddings):
+        >>>         return self.pad(x, paddings)
+        >>> x = np.random.random(size=(2, 3)).astype(np.float32)
+        >>> paddings = Tensor([[1,1],[2,2]])
+        >>> pad = Net()
+        >>> ms_output = pad(Tensor(x), paddings)
+    """
+
+    @prim_attr_register
+    def __init__(self, mode='REFLECT'):
+        """Init MirrorPad"""
+        validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'])
+        self.mode = mode
+
+    def __infer__(self, input_x, paddings):
+        validator.check_subclass("input_x", input_x['dtype'], mstype.tensor)
+        validator.check_subclass("paddings", paddings['dtype'], mstype.tensor)
+        x_shape = list(input_x['shape'])
+        paddings_value = paddings['value'].asnumpy()
+        paddings_size = paddings_value.size
+        validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ)
+        if not np.all(paddings_value >= 0):
+            raise ValueError('All elements of paddings must be >= 0.')
+        y_shape = ()
+        for i in range(0, paddings_size // 2):
+            y_shape += ((x_shape[i] + paddings_value[i, 0] + paddings_value[i, 1]),)
+
+        return {'shape': y_shape,
+                'dtype': input_x['dtype'],
+                'value': None}
+
+
 class ROIAlign(PrimitiveWithInfer):
     """
     Computes Region of Interest (RoI) Align operator.
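Before the unit test below, an end-to-end sketch of the REFLECT path: nn.Pad with a mirror mode routes through P.MirrorPad, so differentiating it exercises the new bprop and the MirrorPadGrad primitive. ReflectPadNet and GradNet are illustrative names, and running this assumes a backend that provides a MirrorPad kernel:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops.composite import GradOperation

    class ReflectPadNet(nn.Cell):
        def __init__(self):
            super(ReflectPadNet, self).__init__()
            # REFLECT is dispatched to P.MirrorPad inside nn.Pad.
            self.pad = nn.Pad(paddings=((1, 1), (2, 2)), mode="REFLECT")

        def construct(self, x):
            return self.pad(x)

    class GradNet(nn.Cell):
        def __init__(self, network):
            super(GradNet, self).__init__()
            self.grad = GradOperation(name="get_all", get_all=True, sens_param=True)
            self.network = network

        def construct(self, x, sens):
            return self.grad(self.network)(x, sens)

    x = Tensor(np.random.random(size=(2, 3)).astype(np.float32))
    # Padding (2, 3) with ((1, 1), (2, 2)) yields a (4, 7) output, so the
    # sensitivity fed to the gradient must have that shape.
    sens = Tensor(np.random.random(size=(4, 7)).astype(np.float32))
    dx = GradNet(ReflectPadNet())(x, sens)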
diff --git a/tests/ut/python/nn/test_nn_pad.py b/tests/ut/python/nn/test_nn_pad.py
new file mode 100644
index 0000000000..a8b66bae5c
--- /dev/null
+++ b/tests/ut/python/nn/test_nn_pad.py
@@ -0,0 +1,64 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test nn pad """
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+from mindspore.ops.composite import GradOperation
+from mindspore.common.api import ms_function
+import numpy as np
+import mindspore.context as context
+
+
+class Net(nn.Cell):
+    def __init__(self, raw_paddings, mode):
+        super(Net, self).__init__()
+        self.pad = nn.Pad(raw_paddings, mode=mode)
+
+    @ms_function
+    def construct(self, x):
+        return self.pad(x)
+
+
+class Grad(nn.Cell):
+    def __init__(self, network):
+        super(Grad, self).__init__()
+        self.grad = GradOperation(name="get_all", get_all=True, sens_param=True)
+        self.network = network
+
+    @ms_function
+    def construct(self, x, grads):
+        return self.grad(self.network)(x, grads)
+
+
+def test_pad_train():
+    mode = 'CONSTANT'
+    x = np.random.random(size=(2, 3)).astype(np.float32)
+    raw_paddings = ((1, 1), (2, 2))
+    grads = np.random.random(size=(4, 7)).astype(np.float32)
+    grad = Grad(Net(raw_paddings, mode))
+    output = grad(Tensor(x), Tensor(grads))
+    print("=================output====================")
+    print(output)
+
+
+def test_pad_infer():
+    mode = 'CONSTANT'
+    x = np.random.random(size=(2, 3)).astype(np.float32)
+    raw_paddings = ((1, 1), (2, 2))
+    net = Net(raw_paddings, mode)
+    output = net(Tensor(x))
+    print("=================output====================")
+    print(output)
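The new tests cover only the CONSTANT path. A hypothetical mirror-mode check in the same style, cross-checking against NumPy (it assumes the REFLECT output matches np.pad's 'reflect' mode, as the docstring examples suggest, and a backend where MirrorPad can execute):

    def test_pad_reflect_infer():
        x = np.random.random(size=(2, 3)).astype(np.float32)
        raw_paddings = ((1, 1), (2, 2))
        net = Net(raw_paddings, 'REFLECT')
        ms_output = net(Tensor(x)).asnumpy()
        # np.pad 'reflect' reproduces the REFLECT examples in the docstrings.
        np_output = np.pad(x, raw_paddings, mode='reflect')
        assert np.allclose(ms_output, np_output)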