add vm for cosh and sinh

pull/1757/head
lihongkang 5 years ago
parent 00672a47b8
commit 436b309915

@ -793,6 +793,18 @@ def get_bprop_asinh(self):
return bprop
@bprop_getters.register(P.Sinh)
def get_bprop_sinh(self):
"""Grad definition for `Sinh` operation."""
cosh = P.Cosh()
def bprop(x, out, dout):
dx = cosh(x) * dout
return (dx,)
return bprop
@bprop_getters.register(P.Cos)
def get_bprop_cos(self):
"""Grad definition for `Cos` operation."""
@ -830,6 +842,18 @@ def get_bprop_acosh(self):
return bprop
@bprop_getters.register(P.Cosh)
def get_bprop_cosh(self):
"""Grad definition for `Cosh` operation."""
sinh = P.Sinh()
def bprop(x, out, dout):
dx = sinh(x) * dout
return (dx,)
return bprop
@bprop_getters.register(P.Abs)
def get_bprop_abs(self):
"""Grad definition for `Abs` operation."""

@ -227,3 +227,5 @@ from .asinh_grad import _asinh_grad_tbe
from .atan import _atan_tbe
from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe
from .cosh import _cosh_tbe
from .sinh import _sinh_tbe

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cosh op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
cosh_op_info = TBERegOp("Cosh") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("cosh.so") \
.compute_cost(10) \
.kernel_name("cosh") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(cosh_op_info)
def _cosh_tbe():
"""Cosh TBE register"""
return

@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sinh op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
sinh_op_info = TBERegOp("Sinh") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("sinh.so") \
.compute_cost(10) \
.kernel_name("sinh") \
.partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", False, "required", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()
@op_info_register(sinh_op_info)
def _sinh_tbe():
"""Sinh TBE register"""
return

@ -40,7 +40,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor,
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
BitwiseXor,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
@ -50,7 +51,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh)
from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
@ -245,6 +247,7 @@ __all__ = [
'Asinh',
"PReLU",
"Cos",
"Cosh",
"ACos",
"Diag",
"DiagPart",
@ -253,6 +256,7 @@ __all__ = [
'AssignAdd',
'AssignSub',
"Sin",
"Sinh",
"Asin",
"LSTM",
"Abs",

@ -1359,6 +1359,35 @@ class Acosh(PrimitiveWithInfer):
return x_dtype
class Cosh(PrimitiveWithInfer):
"""
Computes hyperbolic cosine of input element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor, has the same shape as `input_x`.
Examples:
>>> cosh = P.Cosh()
>>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
>>> output = cosh(input_x)
[1.0289385 1.364684 1.048436 1.4228927]
"""
@prim_attr_register
def __init__(self):
"""init Cosh"""
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
return x_dtype
class Asinh(PrimitiveWithInfer):
"""
Compute inverse hyperbolic cosine of x element-wise.
@ -1376,7 +1405,6 @@ class Asinh(PrimitiveWithInfer):
[-2.3212, 1.1976, 1.8184, 5.2983]
"""
@prim_attr_register
def __init__(self):
"""init Asinh"""
@ -1389,6 +1417,35 @@ class Asinh(PrimitiveWithInfer):
return x_dtype
class Sinh(PrimitiveWithInfer):
"""
Computes hyperbolic sine of input element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Outputs:
Tensor, has the same shape as `input_x`.
Examples:
>>> sinh = P.Sinh()
>>> input_x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
>>> output = sinh(input_x)
[0.6604918 0.28367308 0.44337422 0.6604918]
"""
@prim_attr_register
def __init__(self):
"""init Sinh"""
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
return x_dtype
class _LogicBinaryOp(_BinaryOp):
"""
Define logic binary operators.

@ -128,7 +128,7 @@ class NetForFlattenComposed(nn.Cell):
self.flatten = P.Flatten()
def construct(self, x, y):
return self.flatten(x+x) + y
return self.flatten(x + x) + y
class ArgmaxNet(nn.Cell):
@ -281,6 +281,7 @@ class ApplyRMSNet(nn.Cell):
out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
return out
test_case_math_ops = [
('BitwiseAnd', {
'block': P.BitwiseAnd(),
@ -732,6 +733,14 @@ test_case_math_ops = [
'block': P.Atanh(),
'desc_inputs': [[2, 3]],
'desc_bprop': [[2, 3]]}),
('Cosh', {
'block': P.Cosh(),
'desc_inputs': [[3, 4, 5]],
'desc_bprop': [[3, 4, 5]]}),
('Sinh', {
'block': P.Sinh(),
'desc_inputs': [[3, 4, 5]],
'desc_bprop': [[3, 4, 5]]}),
]
test_case_nn_ops = [
@ -1301,7 +1310,7 @@ test_case_array_ops = [
'desc_inputs': [(Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}),
'desc_bprop': [[3, ]]}),
('Pack_0', {
'block': NetForPackInput(P.Pack()),
'desc_inputs': [[2, 2], [2, 2], [2, 2]],
@ -1464,7 +1473,7 @@ test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
test_exec_case = test_case
test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
'backward' not in x[1]['skip'], test_case)
'backward' not in x[1]['skip'], test_case)
@non_graph_engine

Loading…
Cancel
Save