!11174 add_mish_mulnonan_selu_operations

From: @jiangzg001
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
Branch: pull/11174/MERGE
Committed-by: mindspore-ci-bot (via Gitee, 4 years ago)
Commit: 8170669909

@@ -416,6 +416,58 @@ def get_bprop_dropout_do_mask(self):
    return bprop


@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Grad definition for `Mish` operation."""
    tanh = P.Tanh()
    tanh_grad = SG.TanhGrad()
    softplus = P.Softplus()
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx1 = tanh(softplus(x))
        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        dx = dx1 * dout + dx2
        return (dx,)

    return bprop
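Not part of the patch, but a useful sanity check: the two terms in `bprop` compose to the analytic Mish derivative, d/dx [x * tanh(softplus(x))] = tanh(softplus(x)) + x * (1 - tanh(softplus(x))^2) * sigmoid(x). A minimal NumPy sketch, assuming the usual semantics TanhGrad(y, dy) = dy * (1 - y^2) and SoftplusGrad(dy, x) = dy * sigmoid(x):

import numpy as np

def mish(x):
    return x * np.tanh(np.log1p(np.exp(x)))

def mish_grad(x, dout):
    sp = np.log1p(np.exp(x))              # softplus(x)
    dx1 = np.tanh(sp)                     # forward tanh(softplus(x))
    t = (1.0 - dx1 ** 2) * (x * dout)     # TanhGrad(dx1, x * dout), assumed dy * (1 - y^2)
    dx2 = t / (1.0 + np.exp(-x))          # SoftplusGrad(t, x), assumed t * sigmoid(x)
    return dx1 * dout + dx2

x = np.linspace(-3.0, 3.0, 13)
eps = 1e-6
numeric = (mish(x + eps) - mish(x - eps)) / (2.0 * eps)
assert np.allclose(mish_grad(x, np.ones_like(x)), numeric, atol=1e-5)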
@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Grad definition for `SeLU` operation."""
    scale = 1.0507009873554804934193349852946
    elu_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = elu_grad(dout, out) * scale
        return (dx,)

    return bprop
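For reference (not part of the patch): the `scale` literal above is the standard SeLU constant from Klambauer et al., "Self-Normalizing Neural Networks", whose companion constant is alpha ~= 1.6732632423543772. A NumPy finite-difference sketch of the analytic derivative the bprop targets, d/dx selu(x) = scale for x > 0 and scale * alpha * exp(x) otherwise:

import numpy as np

SCALE = 1.0507009873554805
ALPHA = 1.6732632423543772

def selu(x):
    return SCALE * np.where(x > 0, x, ALPHA * np.expm1(x))

def selu_grad(x):
    # scale for x > 0, scale * alpha * exp(x) otherwise
    return SCALE * np.where(x > 0, 1.0, ALPHA * np.exp(x))

x = np.linspace(-2.0, 2.0, 8)   # 8 points: avoids x = 0, where SeLU has a kink
eps = 1e-6
numeric = (selu(x + eps) - selu(x - eps)) / (2.0 * eps)
assert np.allclose(selu_grad(x), numeric, atol=1e-5)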
@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Grad definition for `MulNoNan` operation."""
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        x_shape = F.shape(x)
        y_shape = F.shape(y)
        dx = mul_no_nan(dout, y)
        dy = mul_no_nan(x, dout)
        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
        if broadcast_x != ():
            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
        if broadcast_y != ():
            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
        return dx, dy

    return bprop
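The reduce/reshape dance undoes forward-pass broadcasting: gradient contributions are summed over the axes each operand was broadcast along, then reshaped back to that operand's shape. A NumPy sketch (not part of the patch; `reduce_broadcast_grad` is a hypothetical stand-in for `F.broadcast_gradient_args` plus `ReduceSum`/`Reshape`, and plain multiply stands in for MulNoNan since no zeros are involved):

import numpy as np

def reduce_broadcast_grad(grad, shape):
    # Sum over leading axes the operand lacked, plus size-1 axes that were
    # stretched, then reshape back to the operand's original shape.
    lead = grad.ndim - len(shape)
    axes = tuple(range(lead)) + tuple(
        lead + i for i, s in enumerate(shape) if s == 1 and grad.shape[lead + i] != 1)
    return grad.sum(axis=axes).reshape(shape) if axes else grad

x = np.random.rand(3, 1, 5).astype(np.float32)
y = np.random.rand(4, 5).astype(np.float32)
dout = np.ones((3, 4, 5), dtype=np.float32)     # incoming gradient w.r.t. x * y
dx = reduce_broadcast_grad(dout * y, x.shape)   # back to shape (3, 1, 5)
dy = reduce_broadcast_grad(x * dout, y.shape)   # back to shape (4, 5)
assert dx.shape == x.shape and dy.shape == y.shape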
@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Grad definition for `ReLU` operation."""

@@ -356,3 +356,6 @@ from .lamb_apply_optimizer_assign import _lamb_apply_optimizer_assign_tbe
from .lamb_apply_weight_assign import _lamb_apply_weight_assign_tbe
from .nll_loss import _nll_loss_tbe
from .nll_loss_grad import _nll_loss_grad_tbe
from .mish import _mish_tbe
from .mul_no_nan import _mul_no_nan_tbe
from .selu import _selu_tbe

@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mish op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mish_op_info = TBERegOp("Mish") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mish.so") \
    .compute_cost(10) \
    .kernel_name("mish") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(mish_op_info)
def _mish_tbe():
    """Mish TBE register"""
    return

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MulNoNan op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

mul_no_nan_op_info = TBERegOp("MulNoNan") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("mul_no_nan.so") \
    .compute_cost(10) \
    .kernel_name("mul_no_nan") \
    .partial_flag(True) \
    .input(0, "x1", False, "required", "all") \
    .input(1, "x2", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("broadcast") \
    .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \
    .get_op_info()


@op_info_register(mul_no_nan_op_info)
def _mul_no_nan_tbe():
    """MulNoNan TBE register"""
    return

@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Selu op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

selu_op_info = TBERegOp("Selu") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("selu.so") \
    .compute_cost(10) \
    .kernel_name("selu") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.I8_None, DataType.I8_None) \
    .dtype_format(DataType.I32_None, DataType.I32_None) \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(selu_op_info)
def _selu_tbe():
    """Selu TBE register"""
    return

@@ -48,7 +48,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, ReduceAny,
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
-LogicalNot, LogicalOr, MatMul, Maximum,
+LogicalNot, LogicalOr, MatMul, Maximum, MulNoNan,
Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
@@ -70,8 +70,8 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
LogSoftmax,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
-MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
-ResizeBilinear, Sigmoid,
+MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
+ResizeBilinear, Sigmoid, SeLU,
SigmoidCrossEntropyWithLogits, NLLLoss,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
@@ -194,6 +194,9 @@ __all__ = [
    'ZerosLike',
    'Select',
    'Split',
    'Mish',
    'SeLU',
    'MulNoNan',
    'ReLU',
    'ReLU6',
    'Elu',

@@ -2035,6 +2035,58 @@ class DivNoNan(_MathBinaryOp):
        return None


class MulNoNan(_MathBinaryOp):
    r"""
    Computes `x` * `y` element-wise. If `y` is zero, the result is zero regardless of the value
    of `x`, even when `x` is NaN or infinity.

    Inputs of `input_x` and `input_y` comply with the implicit type conversion rules to make the
    data types consistent. The inputs must be two tensors or one tensor and one scalar.
    When the inputs are two tensors, their shapes can be broadcast.
    When the inputs are one tensor and one scalar, the scalar can only be a constant.

    Note:
        The shapes of `input_x` and `input_y` should be the same or broadcastable.

    Inputs:
        - **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
        - **input_y** (Union[Tensor]) - The second input is a tensor whose data type is number.

    Outputs:
        Tensor, the shape is the same as the one after broadcasting,
        and the data type is the one with higher precision or higher digits among the two inputs.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If `input_x` or `input_y` is a bool tensor.

    Examples:
        >>> x = Tensor(np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]]), ms.float32)
        >>> y = Tensor(np.array([[-1.0, 4.0, 0], [0, -3.0, 1.0]]), ms.float32)
        >>> mul_no_nan = ops.MulNoNan()
        >>> output = mul_no_nan(x, y)
        >>> print(output)
        [[ 1. 24.  0.]
         [ 0. 21.  4.]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize MulNoNan"""
        self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])

    def infer_value(self, x, y):
        if x is not None and y is not None:
            x = x.asnumpy()
            y = y.asnumpy()
            with np.errstate(divide='ignore', invalid='ignore'):
                out = np.multiply(x, y)
            out[y == 0] = 0
            return out
        return None
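A standalone NumPy sketch (not part of the patch) of the same semantics as `infer_value` above, reproducing the docstring example:

import numpy as np

def mul_no_nan(x, y):
    # Multiply, then force positions where y == 0 to exactly 0,
    # so 0 * inf and 0 * nan yield 0 instead of nan.
    with np.errstate(invalid='ignore'):
        out = np.multiply(x, y)
    out[y == 0] = 0
    return out

x = np.array([[-1.0, 6.0, np.inf], [np.nan, -7.0, 4.0]], dtype=np.float32)
y = np.array([[-1.0, 4.0, 0.0], [0.0, -3.0, 1.0]], dtype=np.float32)
print(mul_no_nan(x, y))
# [[ 1. 24.  0.]
#  [ 0. 21.  4.]]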
class FloorDiv(_MathBinaryOp):
    """
    Divides the first input tensor by the second input tensor element-wise and rounds down to the closest integer.
@@ -4041,6 +4093,7 @@ class LinSpace(PrimitiveWithInfer):
               'value': None}
        return out


class MatrixInverse(PrimitiveWithInfer):
    """
    Returns the inverse of the input matrix. If the matrix is not invertible, an error may be reported or an unknown

@@ -329,6 +329,99 @@ class ReLU(PrimitiveWithCheck):
        validator.check_tensor_dtype_valid('input_x', input_x, mstype.number_type, self.name)


class Mish(PrimitiveWithInfer):
    r"""
    Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function) of input tensors
    element-wise.

    The function is shown as follows:

    .. math::
        \text{output} = x * \tanh(\log(1 + \exp(x)))

    Inputs:
        - **x** (Tensor) - The input tensor. Only float16 and float32 are supported.

    Outputs:
        Tensor, with the same type and shape as `x`.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If the data type of `x` is neither float16 nor float32.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> mish = ops.Mish()
        >>> output = mish(input_x)
        >>> print(output)
        [[-3.034014e-01  3.997413e+00 -2.682209e-03]
         [ 1.943959e+00 -3.357619e-02  8.999999e+00]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Mish"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_dtype_valid('x', x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype
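A quick NumPy sketch (not part of the patch) reproducing the formula and the docstring example above:

import numpy as np

x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
print(x * np.tanh(np.log1p(np.exp(x))))   # matches the example output up to float32 rounding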
class SeLU(PrimitiveWithInfer):
    r"""
    Computes SeLU (Scaled Exponential Linear Unit) of input tensors element-wise.

    The activation function is defined as:

    .. math::
        E_{i} =
        scale *
        \begin{cases}
        x_{i}, &\text{if } x_{i} \geq 0; \cr
        \text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
        \end{cases}

    where :math:`alpha` and :math:`scale` are pre-defined constants
    (:math:`alpha \approx 1.67326324` and :math:`scale \approx 1.05070098`).

    Inputs:
        - **input_x** (Tensor) - The input tensor. Only int8, int32, float16 and float32 are supported.

    Outputs:
        Tensor, with the same type and shape as `input_x`.

    Supported Platforms:
        ``Ascend``

    Raises:
        TypeError: If the data type of `input_x` is not one of int8, int32, float16 or float32.

    Examples:
        >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
        >>> selu = ops.SeLU()
        >>> output = selu(input_x)
        >>> print(output)
        [[-1.1113307  4.202804  -1.7575096]
         [ 2.101402  -1.7462534  9.456309 ]]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize SeLU"""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
        return x_dtype
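As with Mish, a short NumPy sketch (not part of the patch) reproduces the example, using the constants noted above:

import numpy as np

scale, alpha = 1.0507009873554805, 1.6732632423543772
x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
print(scale * np.where(x >= 0, x, alpha * np.expm1(x)))   # matches the example up to rounding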
class ReLU6(PrimitiveWithInfer):
    r"""
    Computes ReLU (Rectified Linear Unit) upper bounded by 6 of input tensors element-wise.

@@ -320,6 +320,42 @@ class CountNonZero(nn.Cell):
        return nonzero_num


class Mish(nn.Cell):
    """Mish net definition"""

    def __init__(self):
        super(Mish, self).__init__()
        self.mish = P.Mish()

    def construct(self, input_x):
        out = self.mish(input_x)
        return out


class SeLU(nn.Cell):
    """SeLU net definition"""

    def __init__(self):
        super(SeLU, self).__init__()
        self.selu = P.SeLU()

    def construct(self, input_x):
        out = self.selu(input_x)
        return out


class MulNoNan(nn.Cell):
    """MulNoNan net definition"""

    def __init__(self):
        super(MulNoNan, self).__init__()
        self.mul_no_nan = P.MulNoNan()

    def construct(self, input_x, input_y):
        out = self.mul_no_nan(input_x, input_y)
        return out
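A minimal usage sketch for these wrappers (not part of the patch; assumes a MindSpore install with an Ascend backend, since the new TBE kernels are registered for Ascend only):

import numpy as np
from mindspore import Tensor

x = Tensor(np.random.rand(3, 6, 16, 16).astype(np.float32))
y = Tensor(np.random.rand(3, 6, 16, 16).astype(np.float32))
print(Mish()(x).shape)          # (3, 6, 16, 16)
print(SeLU()(x).shape)          # (3, 6, 16, 16)
print(MulNoNan()(x, y).shape)   # (3, 6, 16, 16)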
class ScatterUpdate(nn.Cell):
    """ScatterUpdate net definition"""
@@ -1315,6 +1351,19 @@ test_case_math_ops = [
        Tensor(np.array([-6, -1, -2, -3]), mstype.float32),
        Tensor(np.array([6, 1, 2, 3]), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 16, 5, 4), mstype.float32)]}),
    ('Mish', {
        'block': Mish(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('SeLU', {
        'block': SeLU(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('MulNoNan', {
        'block': MulNoNan(),
        'desc_inputs': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32),
                        Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)],
        'desc_bprop': [Tensor(np.random.rand(3, 6, 16, 16), mstype.float32)]}),
    ('Rank', {
        'block': P.Rank(),
        'desc_inputs': [[2, 3]],
