diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
index 3fda554759..17ac8742f9 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
@@ -57,6 +57,7 @@ static std::map<std::string, std::string> tbe_func_adapter_map = {
   {"strided_slice", "strided_slice_d"},
   {"strided_slice_grad", "strided_slice_grad_d"},
   {"transpose", "transpose_d"},
+  {"fill", "fill_d"},
   {"unsorted_segment_sum", "unsorted_segment_sum_d"},
   {"concat", "concat_d"},
   {"slice", "slice_d"},
diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
index c2f96e54c6..fb47c9fc2a 100644
--- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
+++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
@@ -53,6 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(kExpandDimsOpName, {1});
   Register(kSplitOpName, {0});
   Register(kTopKOpName, {1});
+  Register(kErfOpName, {1});
   Register(kSparseApplyAdagradOpName, {2});
   Register(kResizeNearestNeighborGrad, {1});
 }
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index f05eda69bf..6829a7e888 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -92,6 +92,7 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum";
 constexpr auto kGreaterOpName = "Greater";
 constexpr auto kSqrtOpName = "Sqrt";
 constexpr auto kRsqrtOpName = "Rsqrt";
+constexpr auto kErfOpName = "Erf";
 constexpr auto kRealDivOpName = "RealDiv";
 constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR";
 constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay";
diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py
index 2d819718c8..c334050218 100755
--- a/mindspore/ops/_grad/grad_math_ops.py
+++ b/mindspore/ops/_grad/grad_math_ops.py
@@ -17,6 +17,7 @@
 
 from functools import reduce
+import numpy as np
 from .. import functional as F
 from .. import operations as P
 from ..operations import _grad_ops as G
@@ -333,6 +334,23 @@ def get_bprop_log(self):
     return bprop
 
 
+@bprop_getters.register(P.Erf)
+def get_bprop_erf(self):
+    """Grad definition for `Erf` operation."""
+    exp = P.Exp()
+    square = P.Square()
+    sqrt = P.Sqrt()
+    cast = P.Cast()
+    dtype = P.DType()
+
+    def bprop(x, out, dout):
+        half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x))
+        x_square = square(x)
+        dx = dout * half_root_pi * exp(-x_square)
+        return (dx,)
+    return bprop
+
+
 @bprop_getters.register(P.Pow)
 def get_bprop_pow(self):
     """Grad definition for `Pow` operation."""
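Side note on the gradient above: the bprop encodes d/dx erf(x) = 2/sqrt(pi) * exp(-x^2). A quick sanity check of that formula in plain NumPy/SciPy, outside MindSpore (the helper name check_erf_grad is illustrative, not part of this patch):

import numpy as np
from scipy.special import erf

def check_erf_grad(x, eps=1e-6):
    # analytic gradient used by the bprop: 2/sqrt(pi) * exp(-x^2)
    analytic = 2.0 / np.sqrt(np.pi) * np.exp(-x * x)
    # central finite difference of erf around x
    numeric = (erf(x + eps) - erf(x - eps)) / (2.0 * eps)
    return np.max(np.abs(analytic - numeric))

print(check_erf_grad(np.linspace(-3.0, 3.0, 61)))  # expect roughly 1e-10 or below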
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 37da184869..18ef92ca6e 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -139,6 +139,8 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
 from .fused_mul_add_n import _fused_mul_add_n_tbe
 from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
+from .fill_d import _fill_d_op_tbe
+from .erf import _erf_op_tbe
 from .depthwise_conv2d import _depthwise_conv2d_tbe
 from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe
 from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe
diff --git a/mindspore/ops/_op_impl/tbe/erf.py b/mindspore/ops/_op_impl/tbe/erf.py
new file mode 100644
index 0000000000..2247197c4e
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/erf.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Erf op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+erf_op_info = TBERegOp("Erf") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("erf.so") \
+    .compute_cost(10) \
+    .kernel_name("erf") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(erf_op_info)
+def _erf_op_tbe():
+    """Erf TBE register"""
+    return
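The next hunk registers FillD, the static-shape variant that the "fill" -> "fill_d" entry in tbe_func_adapter_map above maps to; the output shape travels as the required "dims" attribute instead of a runtime input. As a reading aid, the expected contract corresponds to numpy.full (an interpretation of the registration, not code from this patch):

import numpy as np

def fill_d_reference(dims, value):
    # FillD reference semantics: produce a tensor of shape `dims`
    # with every element set to the fill `value`
    return np.full(dims, value)

print(fill_d_reference((2, 3), 1.5))  # [[1.5 1.5 1.5], [1.5 1.5 1.5]]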
diff --git a/mindspore/ops/_op_impl/tbe/fill_d.py b/mindspore/ops/_op_impl/tbe/fill_d.py
new file mode 100644
index 0000000000..97c6b73cf5
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/fill_d.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""FillD op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+fill_d_op_info = TBERegOp("FillD") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("fill_d.so") \
+    .compute_cost(10) \
+    .kernel_name("fill_d") \
+    .partial_flag(True) \
+    .attr("dims", "required", "listInt", "all") \
+    .input(0, "value", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \
+    .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \
+    .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
+    .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \
+    .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \
+    .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \
+    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \
+    .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \
+    .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \
+    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
+    .get_op_info()
+
+
+@op_info_register(fill_d_op_info)
+def _fill_d_op_tbe():
+    """FillD TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 2860690b91..80b03a04e1 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast
 from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
-                       Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv, FloorMod, Acosh,
+                       Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh,
                        Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd,
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
@@ -139,6 +139,7 @@ __all__ = [
     'ReLU',
     'ReLU6',
     'Elu',
+    'Erf',
     'Sigmoid',
     'HSwish',
     'HSigmoid',
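The math_ops.py hunk below adds the user-facing primitive. Its docstring refers to the Gauss error function, erf(x) = 2/sqrt(pi) * integral from 0 to x of exp(-t^2) dt. A plain-NumPy quadrature sketch reproduces the reference values; the docstring's sample output agrees with these to about four decimal places:

import numpy as np

def erf_quad(x, n=200001):
    # erf(x) = 2/sqrt(pi) * integral_0^x exp(-t^2) dt, trapezoidal rule
    t = np.linspace(0.0, x, n)
    return 2.0 / np.sqrt(np.pi) * np.trapz(np.exp(-t * t), t)

for x in (-1.0, 0.0, 1.0, 2.0, 3.0):
    print(x, erf_quad(x))  # approx. -0.84270, 0.0, 0.84270, 0.99532, 0.99998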
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 33351a3ca1..8de4108435 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -1007,6 +1007,36 @@ class Log(PrimitiveWithInfer):
         return x
 
 
+class Erf(PrimitiveWithInfer):
+    r"""
+    Computes the Gauss error function of `input_x` element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape and dtype as `input_x`.
+
+    Examples:
+        >>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32)
+        >>> erf = P.Erf()
+        >>> erf(input_x)
+        [-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init Erf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['y'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_type):
+        validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name)
+        return x_type
+
+
 class Minimum(_MathBinaryOp):
     """
     Computes the element-wise minimum of input tensors.
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 1bd3a2e438..442c8bdec6 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -250,6 +250,10 @@ test_case_math_ops = [
         'block': P.Exp(),
         'desc_inputs': [[2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('Erf', {
+        'block': P.Erf(),
+        'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))],
+        'desc_bprop': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))]}),
     ('Floor', {
         'block': P.Floor(),
         'desc_inputs': [[2, 512, 56, 56]],
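On the test entry: desc_inputs supplies the forward input and desc_bprop the gradient seed, which must match the output's shape and dtype, hence the repeated float16 vector. For eyeballing the expected forward values at the test's precision, a small oracle sketch (scipy.special.erf as reference, computed in float32 and rounded back to float16):

import numpy as np
from scipy.special import erf

x = np.array([-2, -1, 0, 1, 2], dtype=np.float16)
ref = erf(x.astype(np.float32)).astype(np.float16)
print(ref)  # elementwise erf of the test vector at float16 precision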