From e71599b5cac553df5142c2e698b9107088609b92 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 16 Jun 2020 16:44:34 +0800 Subject: [PATCH] vm for lin_space --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/nn/layer/math.py | 50 ++++++++++++++++++++++- mindspore/ops/_grad/grad_math_ops.py | 11 +++++ mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/lin_space.py | 40 ++++++++++++++++++ mindspore/ops/operations/_inner_ops.py | 39 ++++++++++++++++++ tests/ut/python/ops/test_ops.py | 8 ++++ 7 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/tbe/lin_space.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 3007280a14..9fd71e0e30 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -112,6 +112,7 @@ static std::map tbe_func_adapter_map = { {"square_sum_all", "square_sum_all"}, {"cum_sum", "cumsum_d"}, {"range", "range_d"}, + {"lin_space", "lin_space_d"}, {"inv_grad", "inv_grad"}, {"apply_rms_prop", "apply_rms_prop_d"}, {"cum_prod", "cumprod_d"}, diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 78652e5a40..1ecb20056e 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -20,8 +20,11 @@ from mindspore.common.tensor import Tensor from ..cell import Cell from ...common import dtype as mstype from ..._checkparam import Validator as validator +from ..._checkparam import Rel + + +__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace'] -__all__ = ['ReduceLogSumExp', 'Range'] class ReduceLogSumExp(Cell): r""" @@ -125,3 +128,48 @@ class Range(Cell): def construct(self): range_out = self.range_x(self.input_tensor) return range_out + + +class LinSpace(Cell): + r""" + Generates values in an interval. And return the corresponding interpolation accroding to assist. + + Args: + - **start** (Union[int, float]) - The start of interval, With shape of 0-D. + - **stop** (Union[int, float]) - The end of interval, With shape of 0-D. + - **num** (int) - ticks number in the interval, the ticks include start and stop value. + With shape of 0-D. + + Outputs: + Tensor, With type same as `start`. The shape is 1-D with length of `num`. + + Examples: + >>> linspace = nn.LinSpace() + >>> start = Tensor(1, mindspore.float32) + >>> stop = Tensor(10, mindspore.float32) + >>> num = Tensor(5, mindspore.int32) + >>> output = linspace(start, stop, num) + [1, 3.25, 5.5, 7.75, 10] + """ + + def __init__(self, start, stop, num): + super(LinSpace, self).__init__() + validator.check_value_type("start", start, [int, float], self.cls_name) + validator.check_value_type("stop", stop, [int, float], self.cls_name) + validator.check_value_type("num", num, [int], self.cls_name) + validator.check_integer("num", num, 0, Rel.GT, self.cls_name) + + self.is_single = bool(num == 1) + self.lin_space = inner.LinSpace() + self.start = Tensor(start, mstype.float32) + self.stop = Tensor(stop, mstype.float32) + self.assist = Tensor(list(range(num)), mstype.float32) + self.num = Tensor(num, mstype.int32) + self.start_array = Tensor([start], mstype.float32) + + def construct(self): + if self.is_single: + return self.start_array + + lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num) + return lin_space_out diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 2a8a4fb03b..acc4bc0672 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -21,6 +21,7 @@ from mindspore.ops import _selected_grad_ops as SG from .. import functional as F from .. import operations as P from ..operations import _grad_ops as G +from ..operations import _inner_ops as inner from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..functional import broadcast_gradient_args, reduced_shape, tuple_div from .grad_base import bprop_getters @@ -1049,3 +1050,13 @@ def get_bprop_inv(self): dx = inv_grad(out, dout) return (dx,) return bprop + + +@bprop_getters.register(inner.LinSpace) +def get_bprop_lin_space(self): + """Grad definition for `LinSpace` operation.""" + + def bprop(assist, start, stop, num, out, dout): + return zeros_like(assist), zeros_like(start), zeros_like(stop), zeros_like(num) + + return bprop diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index b68b29fcb8..0537d4b4f2 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -262,3 +262,4 @@ from .tensor_scatter_update import _tensor_scatter_update_tbe from .inplace_update import _inplace_update_tbe from .splitv import _split_v_tbe from .in_top_k import _in_top_k_tbe +from .lin_space import _lin_space_tbe diff --git a/mindspore/ops/_op_impl/tbe/lin_space.py b/mindspore/ops/_op_impl/tbe/lin_space.py new file mode 100644 index 0000000000..aed41e80d4 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/lin_space.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""LinSpace op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +lin_space_op_info = TBERegOp("LinSpace") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("lin_space.so") \ + .compute_cost(10) \ + .kernel_name("lin_space") \ + .partial_flag(True) \ + .op_pattern("broadcast") \ + .input(0, "assist", False, "required", "all") \ + .input(1, "start", False, "required", "all") \ + .input(2, "stop", False, "required", "all") \ + .input(3, "num", False, "required", "all") \ + .output(0, "output", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, + DataType.F32_Default,) \ + .get_op_info() + + +@op_info_register(lin_space_op_info) +def _lin_space_tbe(): + """LinSpace TBE register""" + return diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index 699c34f6a3..e89a104623 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -328,3 +328,42 @@ class EmbeddingLookup(PrimitiveWithInfer): 'dtype': params['dtype'], 'value': None} return out + + +class LinSpace(PrimitiveWithInfer): + r""" + Generates values in an interval. And return the corresponding interpolation accroding to assist. + + Inputs: + - **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D. + - **start** (Tensor[float32]) - The start of interval, With shape of 0-D. + - **stop** (Tensor[float32]) - The end of interval, With shape of 0-D. + - **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value. + With shape of 0-D. + + Outputs: + Tensor, has the same shape as `assist`. + + Examples: + >>> linspace = P.LinSpace() + >>> assist = Tensor([5, 5.5], mindspore.float32) + >>> start = Tensor(1, mindspore.float32) + >>> stop = Tensor(10, mindspore.float32) + >>> num = Tensor(5, mindspore.int32) + >>> output = linspace(assist, start, stop, num) + [12.25, 13.375] + """ + + @prim_attr_register + def __init__(self): + pass + + def infer_shape(self, assist, start, stop, num): + return assist + + def infer_dtype(self, assist, start, stop, num): + args = {"num": num} + validator.check_tensor_type_same(args, (mstype.int32,), self.name) + args = {"assist": assist, "start": start, "stop": stop} + validator.check_tensor_type_same(args, (mstype.float32,), self.name) + return assist diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 5b5fd57aa9..5486a4319c 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1599,6 +1599,14 @@ test_case_array_ops = [ 'desc_inputs': [Tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(np.float32)), Tensor(np.array([1, 2, 3]).astype(np.int32))], 'desc_bprop': [[3, 3]]}), + ('LinSpace', { + 'block': inner.LinSpace(), + 'desc_inputs': [Tensor([5, 5.5], mstype.float32), + Tensor(1, mstype.float32), + Tensor(10, mstype.float32), + Tensor(5, mstype.int32)], + 'skip': ['backward'], + }), ] test_case_other_ops = [