From 991fb122f3946c9d49529b1389a1d70410e09eb8 Mon Sep 17 00:00:00 2001
From: jiangzhenguang
Date: Tue, 12 Jan 2021 10:11:22 +0800
Subject: [PATCH] add_mish_mulnonan_selu_operations

---
 mindspore/ops/_grad/grad_nn_ops.py       | 54 +++++++++++++-
 mindspore/ops/_op_impl/tbe/__init__.py   |  3 +
 mindspore/ops/_op_impl/tbe/mish.py       | 37 +++++++++
 mindspore/ops/_op_impl/tbe/mul_no_nan.py | 39 ++++++++++
 mindspore/ops/_op_impl/tbe/selu.py       | 39 ++++++++++
 mindspore/ops/operations/__init__.py     |  9 ++-
 mindspore/ops/operations/math_ops.py     | 53 +++++++++++++
 mindspore/ops/operations/nn_ops.py       | 95 +++++++++++++++++++++++-
 tests/ut/python/ops/test_ops.py          | 49 ++++++++++++
 9 files changed, 373 insertions(+), 5 deletions(-)
 create mode 100644 mindspore/ops/_op_impl/tbe/mish.py
 create mode 100644 mindspore/ops/_op_impl/tbe/mul_no_nan.py
 create mode 100644 mindspore/ops/_op_impl/tbe/selu.py

diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index a61499f9e7..fdee5fb995 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -280,7 +280,7 @@ def _get_mean_matrix(x_shape, ksize, stride, pad_mode, x_dtype):
     the value of element which is padded is 0, else are 1.
     For each element of output, it is mapped for slide window:
     `[h*h_stride : h*h_stride + h_ksize, w*w_stride : w*w_stride + w_ksize]` of
     `assist_input_matrix`, so the sum of slide window is the
-    number of input that assosiate with output element.
+    number of input elements associated with the output element.
     """

     n_input, c_input, h_input, w_input = x_shape
@@ -416,6 +416,58 @@ def get_bprop_dropout_do_mask(self):
     return bprop
 
 
+@bprop_getters.register(P.Mish)
+def get_bprop_mish(self):
+    """Grad definition for `Mish` operation."""
+    tanh = P.Tanh()
+    tanh_grad = SG.TanhGrad()
+    softplus = P.Softplus()
+    softplus_grad = G.SoftplusGrad()
+
+    def bprop(x, out, dout):
+        dx1 = tanh(softplus(x))
+        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
+        dx = (dx1 * dout + dx2)
+        return (dx,)
+
+    return bprop
+
+
+@bprop_getters.register(P.SeLU)
+def get_bprop_selu(self):
+    """Grad definition for `SeLU` operation."""
+    scale = 1.0507009873554804934193349852946
+    elu_grad = G.EluGrad()
+
+    def bprop(x, out, dout):
+        dx = elu_grad(dout, out) * scale
+        return (dx,)
+
+    return bprop
+
+
+@bprop_getters.register(P.MulNoNan)
+def get_bprop_mul_no_nan(self):
+    """Grad definition for `MulNoNan` operation."""
+    mul_no_nan = P.MulNoNan()
+    reduce_sum = P.ReduceSum()
+    reshape = P.Reshape()
+
+    def bprop(x, y, out, dout):
+        x_shape = F.shape(x)
+        y_shape = F.shape(y)
+        dx = mul_no_nan(dout, y)
+        dy = mul_no_nan(x, dout)
+        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
+        if broadcast_x != ():
+            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
+        if broadcast_y != ():
+            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
+        return dx, dy
+
+    return bprop
+
+
 @bprop_getters.register(P.ReLU)
 def get_bprop_relu(self):
     """Grad definition for `ReLU` operation."""
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index f4a1811073..ecf1a29a2a 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -355,3 +355,6 @@ from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
 from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
 from .nll_loss import _nll_loss_tbe
 from .nll_loss_grad import _nll_loss_grad_tbe
+from .mish import _mish_tbe
+from .mul_no_nan import _mul_no_nan_tbe
+from .selu import _selu_tbe
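Editor's note (not part of the patch): a quick NumPy sanity check of the Mish bprop registered above. TanhGrad(y, dy) computes dy * (1 - y**2) and SoftplusGrad(dy, x) computes dy * sigmoid(x), so the composed bprop reduces to dx = dout * (tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))**2)), which should agree with a central finite difference of mish(x) = x * tanh(softplus(x)). The helper names below are illustrative, not MindSpore APIs.

    import numpy as np

    def softplus(x):
        return np.log1p(np.exp(x))

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def mish(x):
        return x * np.tanh(softplus(x))

    def mish_bprop(x, dout):
        # Mirrors get_bprop_mish: dx1 = tanh(softplus(x)) and
        # dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        #     = sigmoid(x) * (1 - dx1**2) * x * dout
        dx1 = np.tanh(softplus(x))
        dx2 = sigmoid(x) * ((1.0 - dx1 ** 2) * (x * dout))
        return dx1 * dout + dx2

    x = np.linspace(-4.0, 4.0, 17)
    eps = 1e-6
    numeric = (mish(x + eps) - mish(x - eps)) / (2.0 * eps)
    assert np.allclose(mish_bprop(x, np.ones_like(x)), numeric, atol=1e-4)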
diff --git a/mindspore/ops/_op_impl/tbe/mish.py b/mindspore/ops/_op_impl/tbe/mish.py
new file mode 100644
index 0000000000..42473a0b01
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/mish.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mish op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+mish_op_info = TBERegOp("Mish") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("mish.so") \
+    .compute_cost(10) \
+    .kernel_name("mish") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("formatAgnostic") \
+    .dtype_format(DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(mish_op_info)
+def _mish_tbe():
+    """Mish TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/mul_no_nan.py b/mindspore/ops/_op_impl/tbe/mul_no_nan.py
new file mode 100644
index 0000000000..632c291d19
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/mul_no_nan.py
@@ -0,0 +1,39 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""MulNoNan op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+mul_no_nan_op_info = TBERegOp("MulNoNan") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("mul_no_nan.so") \
+    .compute_cost(10) \
+    .kernel_name("mul_no_nan") \
+    .partial_flag(True) \
+    .input(0, "x1", False, "required", "all") \
+    .input(1, "x2", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("broadcast") \
+    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
+    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
+    .get_op_info()
+
+
+@op_info_register(mul_no_nan_op_info)
+def _mul_no_nan_tbe():
+    """MulNoNan TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/selu.py b/mindspore/ops/_op_impl/tbe/selu.py
new file mode 100644
index 0000000000..64d33b3c5a
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/selu.py
@@ -0,0 +1,39 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Selu op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+selu_op_info = TBERegOp("Selu") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("selu.so") \
+    .compute_cost(10) \
+    .kernel_name("selu") \
+    .partial_flag(True) \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", True, "required", "all") \
+    .op_pattern("formatAgnostic") \
+    .dtype_format(DataType.I8_None, DataType.I8_None) \
+    .dtype_format(DataType.I32_None, DataType.I32_None) \
+    .dtype_format(DataType.F16_None, DataType.F16_None) \
+    .dtype_format(DataType.F32_None, DataType.F32_None) \
+    .get_op_info()
+
+
+@op_info_register(selu_op_info)
+def _selu_tbe():
+    """Selu TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 4c731a6b0f..9439e44417 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
                        Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
                        Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
-                       LogicalNot, LogicalOr, MatMul, Maximum,
+                       LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
@@ -70,8 +70,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                      LogSoftmax, MaxPool, DataFormatDimMap,
                      AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
-                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
-                     ResizeBilinear, Sigmoid,
+                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
+                     ResizeBilinear, Sigmoid, SeLU,
                      SigmoidCrossEntropyWithLogits, NLLLoss, SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss,
                      DynamicRNN, DynamicGRUV2,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -194,6 +194,9 @@ __all__ = [
     'ZerosLike',
     'Select',
     'Split',
+    'Mish',
+    'SeLU',
+    'MulNoNan',
     'ReLU',
     'ReLU6',
     'Elu',
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 62962c716d..fd6ad3d118 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2035,6 +2035,58 @@ class DivNoNan(_MathBinaryOp):
         return None
 
 
+class MulNoNan(_MathBinaryOp):
+    r"""
+    Computes `x` * `y` element-wise. If `y` is zero, the result is zero no matter what value `x` takes,
+    even if `x` is NaN or infinity.
+
+    Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the data types
+    consistent. The inputs must be two tensors or one tensor and one scalar. When the inputs are two tensors,
+    their shapes could be broadcast. When the inputs are one tensor and one scalar, the scalar could only be
+    a constant.
+
+    Note:
+        The shapes of `input_x` and `input_y` should be the same or broadcastable.
+
+    Inputs:
+        - **input_x** (Tensor) - The first input is a tensor whose data type is number.
+        - **input_y** (Tensor) - The second input is a tensor whose data type is number.
+
+    Outputs:
+        Tensor, the shape is the same as the one after broadcasting,
+        and the data type is the one with higher precision or higher digits among the two inputs.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Raises:
+        TypeError: If `input_x` or `input_y` is a bool tensor.
+
+    Examples:
+        >>> x = Tensor(np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]]), ms.float32)
+        >>> y = Tensor(np.array([[-1.0, 4.0, 0], [0, -3.0, 1.0]]), ms.float32)
+        >>> mul_no_nan = ops.MulNoNan()
+        >>> output = mul_no_nan(x, y)
+        >>> print(output)
+        [[ 1. 24.  0.]
+         [ 0. 21.  4.]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize MulNoNan"""
+        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
+
+    def infer_value(self, x, y):
+        if x is not None and y is not None:
+            x = x.asnumpy()
+            y = y.asnumpy()
+            with np.errstate(divide='ignore', invalid='ignore'):
+                out = np.multiply(x, y)
+            out[y == 0] = 0
+            return out
+        return None
+
+
 class FloorDiv(_MathBinaryOp):
     """
     Divides the first input tensor by the second input tensor element-wise and round down to the closest integer.
@@ -4041,6 +4093,7 @@ class LinSpace(PrimitiveWithInfer):
                'value': None}
         return out
 
+
 class MatrixInverse(PrimitiveWithInfer):
     """
     Returns the inverse of the input matrix. If the matrix is irreversible, an error may be reported or an unknown
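Editor's note (not part of the patch): the MulNoNan semantics documented above can be sketched in plain NumPy, mirroring `infer_value`: wherever `y` is exactly zero the product is forced to zero, even when `x` is NaN or infinite.

    import numpy as np

    x = np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]], dtype=np.float32)
    y = np.array([[-1.0, 4.0, 0.0], [0.0, -3.0, 1.0]], dtype=np.float32)

    with np.errstate(invalid='ignore'):  # inf * 0 and nan * 0 would warn
        out = np.multiply(x, y)
    out[y == 0] = 0  # mask the NaNs produced by inf * 0 and nan * 0

    print(out)
    # [[ 1. 24.  0.]
    #  [ 0. 21.  4.]]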
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index b474a1ddd4..21321338a7 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -329,6 +329,99 @@ class ReLU(PrimitiveWithCheck):
         validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)
 
 
+class Mish(PrimitiveWithInfer):
+    r"""
+    Computes Mish of input tensors element-wise.
+
+    The function is shown as follows:
+
+    .. math::
+
+        \text{output} = x * \tanh(\log(1 + \exp(x)))
+
+    Inputs:
+        - **x** (Tensor) - The input tensor. Only float16 and float32 are supported.
+
+    Outputs:
+        Tensor, with the same type and shape as `x`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Raises:
+        TypeError: If dtype of `x` is neither float16 nor float32.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> mish = ops.Mish()
+        >>> output = mish(input_x)
+        >>> print(output)
+        [[-3.034014e-01  3.997413e+00 -2.682209e-03]
+         [ 1.943959e+00 -3.357619e-02  8.999999e+00]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Mish"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
+        return x_dtype
+
+
+class SeLU(PrimitiveWithInfer):
+    r"""
+    Computes SeLU (Scaled Exponential Linear Unit) of input tensors element-wise.
+
+    The activation function is defined as:
+
+    .. math::
+        E_{i} =
+        \text{scale} *
+        \begin{cases}
+        x_{i}, &\text{if } x_{i} \geq 0; \cr
+        \text{alpha} * (\exp(x_{i}) - 1), &\text{otherwise.}
+        \end{cases}
+
+    where :math:`\text{alpha} \approx 1.6733` and :math:`\text{scale} \approx 1.0507` are predefined constants.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, with the same type and shape as `input_x`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Raises:
+        TypeError: If dtype of `input_x` is not one of int8, int32, float16, float32.
+
+    Examples:
+        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> selu = ops.SeLU()
+        >>> output = selu(input_x)
+        >>> print(output)
+        [[-1.1113307  4.202804  -1.7575096]
+         [ 2.101402  -1.7462534  9.456309 ]]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize SeLU"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
+        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
+        return x_dtype
+
+
 class ReLU6(PrimitiveWithInfer):
     r"""
     Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.
@@ -1338,7 +1431,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
     :math:`\text{in_channels} * \text{channel_multiplier}` channels.
 
     Args:
-        channel_multiplier (int): The multipiler for the original output convolution. Its value must be greater than 0.
+        channel_multiplier (int): The multiplier for the original output convolution. Its value must be greater than 0.
         kernel_size (Union[int, tuple[int]]): The size of the convolution kernel.
         mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution ,
             2 deconvolution, 3 depthwise convolution. Default: 3.
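Editor's note (not part of the patch): the SeLU docstring example above can be reproduced with the canonical SELU constants; `SCALE` matches the `scale` used by the bprop in grad_nn_ops.py, and `ALPHA` below is the standard value from the SELU paper, not taken from this patch.

    import numpy as np

    SCALE = 1.0507009873554805
    ALPHA = 1.6732632423543772

    def selu(x):
        # scale * x for x >= 0, scale * alpha * (exp(x) - 1) otherwise
        return SCALE * np.where(x >= 0, x, ALPHA * (np.exp(x) - 1.0))

    x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
    print(selu(x))
    # [[-1.1113307  4.202804  -1.7575096]
    #  [ 2.101402  -1.7462534  9.456309 ]]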
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 9b75d870b2..09f9c953c3 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -320,6 +320,42 @@ class CountNonZero(nn.Cell):
         return nonzero_num
 
 
+class Mish(nn.Cell):
+    """Mish net definition"""
+
+    def __init__(self):
+        super(Mish, self).__init__()
+        self.mish = P.Mish()
+
+    def construct(self, input_x):
+        out = self.mish(input_x)
+        return out
+
+
+class SeLU(nn.Cell):
+    """SeLU net definition"""
+
+    def __init__(self):
+        super(SeLU, self).__init__()
+        self.selu = P.SeLU()
+
+    def construct(self, input_x):
+        out = self.selu(input_x)
+        return out
+
+
+class MulNoNan(nn.Cell):
+    """MulNoNan net definition"""
+
+    def __init__(self):
+        super(MulNoNan, self).__init__()
+        self.mul_no_nan = P.MulNoNan()
+
+    def construct(self, input_x, input_y):
+        out = self.mul_no_nan(input_x, input_y)
+        return out
+
+
 class ScatterUpdate(nn.Cell):
     """ScatterUpdate net definition"""
 
@@ -1315,6 +1351,19 @@ test_case_math_ops = [
         Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
         Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
         'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
+    ('Mish', {
+        'block': Mish(),
+        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
+        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
+    ('SeLU', {
+        'block': SeLU(),
+        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
+        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
+    ('MulNoNan', {
+        'block': MulNoNan(),
+        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32),
+                        Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
+        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
     ('Rank', {
         'block': P.Rank(),
         'desc_inputs': [[2, 3]],