add vm support for Expm1

pull/1750/head
zhouneng 5 years ago
parent fd045e9115
commit e5419f7bd1

@@ -422,6 +422,19 @@ def get_bprop_exp(self):
    return bprop


@bprop_getters.register(P.Expm1)
def get_bprop_expm1(self):
    """Grad definition for `Expm1` operation."""
    exp_ = P.Exp()

    def bprop(x, out, dout):
        g = exp_(x)
        dx = g * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Minimum)
def get_bprop_minimum(self):
    """Grad definition for `Minimum` operation."""

@@ -83,6 +83,7 @@ from .strided_slice_d import _strided_slice_d_tbe
from .strided_slice_grad_d import _strided_slice_grad_d_tbe
from .split_d import _split_d_tbe
from .exp import _exp_tbe
from .expm1 import _expm1_tbe
from .elu import _elu_tbe
from .elu_grad import _elu_grad_tbe
from .div import _div_tbe

@@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Expm1 op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
expm1_op_info = TBERegOp("Expm1") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("expm1.so") \
.compute_cost(10) \
.kernel_name("expm1") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(expm1_op_info)
def _expm1_tbe():
"""Expm1 TBE register"""
return
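This registration advertises float16 and float32 support in both the default layout and 5HD (NC1HWC0). The reason Expm1 warrants a dedicated kernel instead of composing Exp with a subtraction is accuracy near zero, where exp(x) - 1 loses most of its significant digits to cancellation. A numpy illustration of the effect (standalone, not part of the commit):

    import numpy as np

    x = np.float32(1e-7)
    naive = np.exp(x) - np.float32(1.0)  # catastrophic cancellation in float32
    fused = np.expm1(x)                  # stays accurate: ~1.0000000e-07
    print(naive, fused)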

@@ -42,7 +42,7 @@ from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor,
                       ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
-                      Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
+                      Cos, Div, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
                       Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
                       LogicalNot, LogicalOr, MatMul, Maximum,
                       Minimum, Mul, Neg, NMSWithMask, NotEqual,
@@ -89,6 +89,7 @@ __all__ = [
    'Mul',
    'Pow',
    'Exp',
    'Expm1',
    'Rsqrt',
    'Sqrt',
    'Square',

@@ -1004,6 +1004,36 @@ class Exp(PrimitiveWithInfer):
        return x_type


class Expm1(PrimitiveWithInfer):
    """
    Returns the exponential of a tensor element-wise, minus 1 (i.e. e^x - 1).

    Inputs:
        - **input_x** (Tensor) - The input tensor.

    Outputs:
        Tensor, has the same shape as the `input_x`.

    Examples:
        >>> input_x = Tensor(np.array([0.0, 1.0, 2.0, 4.0]), mindspore.float32)
        >>> expm1 = P.Expm1()
        >>> expm1(input_x)
        [ 0., 1.71828183, 6.3890561 , 53.59815003]
    """

    @prim_attr_register
    def __init__(self):
        """init Expm1"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_type):
        validator.check_subclass("x", x_type, mstype.tensor, self.name)
        return x_type


class Log(PrimitiveWithInfer):
    """
    Returns the natural logarithm of a tensor element-wise.
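With the primitive, its bprop, and the TBE registration in place, the op can be exercised end to end. A minimal usage sketch (assuming a working MindSpore install with a supported backend; np.expm1 serves as the reference):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    input_x = Tensor(np.array([0.0, 1.0, 2.0, 4.0]), mindspore.float32)
    expm1 = P.Expm1()
    output = expm1(input_x)
    # Check against numpy's reference implementation of e^x - 1.
    assert np.allclose(output.asnumpy(), np.expm1(input_x.asnumpy()), rtol=1e-5)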

@@ -348,6 +348,10 @@ test_case_math_ops = [
        'block': P.Exp(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Expm1', {
        'block': P.Expm1(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Erf', {
        'block': P.Erf(),
        'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))],
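In this test table, a bare shape list such as [[2, 3]] is shorthand the harness expands into a randomly generated tensor of that shape, and 'desc_bprop' supplies the incoming gradient fed to the registered bprop. Spelled out, the Expm1 entry roughly corresponds to the following (a sketch assuming a MindSpore version where GradOperation takes keyword arguments and can wrap a plain Python function, not the harness code itself):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.ops import composite as C

    x = Tensor(np.random.randn(2, 3).astype(np.float32))     # desc_inputs: [[2, 3]]
    sens = Tensor(np.random.randn(2, 3).astype(np.float32))  # desc_bprop: [[2, 3]]

    def net(t):
        return P.Expm1()(t)

    dx = C.GradOperation(sens_param=True)(net)(x, sens)
    # Per get_bprop_expm1, the gradient is exp(x) * sens.
    assert np.allclose(dx.asnumpy(), np.exp(x.asnumpy()) * sens.asnumpy(), rtol=1e-4)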
