diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index c16e16f96c..0c2618aded 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -273,3 +273,6 @@ from .matrix_diag_part import _matrix_diag_part_tbe from .matrix_set_diag import _matrix_set_diag_tbe from .lrn import _lrn_tbe from .lrn_grad import _lrn_grad_tbe +from .scatter_max import _scatter_max_tbe +from .scatter_min import _scatter_min_tbe +from .scatter_sub import _scatter_sub_tbe diff --git a/mindspore/ops/_op_impl/tbe/scatter_max.py b/mindspore/ops/_op_impl/tbe/scatter_max.py new file mode 100644 index 0000000000..ba85d63d35 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_max.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterMax op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_max_op_info = TBERegOp("ScatterMax") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_max.so") \ + .compute_cost(10) \ + .kernel_name("scatter_max") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_max_op_info) +def _scatter_max_tbe(): + """ScatterMax TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_min.py b/mindspore/ops/_op_impl/tbe/scatter_min.py new file mode 100644 index 0000000000..a4ab87a0d7 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_min.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterMin op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_min_op_info = TBERegOp("ScatterMin") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_min.so") \ + .compute_cost(10) \ + .kernel_name("scatter_min") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(scatter_min_op_info) +def _scatter_min_tbe(): + """ScatterMin TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_sub.py b/mindspore/ops/_op_impl/tbe/scatter_sub.py new file mode 100644 index 0000000000..2e5e01389c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_sub.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterSub op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_sub_op_info = TBERegOp("ScatterSub") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_sub.so") \ + .compute_cost(10) \ + .kernel_name("scatter_sub") \ + .partial_flag(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(scatter_sub_op_info) +def _scatter_sub_tbe(): + """ScatterSub TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 78db290784..552a980a0d 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -25,7 +25,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, - SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, + SameTypeShape, ScatterAdd, ScatterSub, ScatterMax, ScatterMin, ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, Shape, Size, Slice, Split, Squeeze, StridedSlice, Tile, TensorScatterUpdate, @@ -214,8 +214,10 @@ __all__ = [ 'BoundingBoxDecode', 'L2Normalize', 'ScatterAdd', + 'ScatterSub', 'ScatterNd', 'ScatterMax', + 'ScatterMin', 'ResizeNearestNeighbor', 'HistogramFixedWidth', 'Pad', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index ced88adec6..dc387353af 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2272,6 +2272,51 @@ class ScatterMax(PrimitiveWithInfer): return x_dtype +class ScatterMin(PrimitiveWithInfer): + """ + Update the value of the input tensor through the min operation. + + Using given values to update tensor value through the min operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do min operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the min operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="input_x") + >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32) + >>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32) + >>> scatter_min = P.ScatterMin() + >>> output = scatter_min(input_x, indices, update) + [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]] + """ + + @prim_attr_register + def __init__(self, use_locking=False): + """Init ScatterMin""" + self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + validator.check_value_type('use_locking', use_locking, (bool,), self.name) + + def infer_shape(self, x_shape, indices_shape, updates_shape): + _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) + args = {"x": x_dtype, "updates": updates_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x_dtype + + class ScatterAdd(PrimitiveWithInfer): """ Update the value of the input tensor through the add operation. @@ -2316,6 +2361,50 @@ class ScatterAdd(PrimitiveWithInfer): return x_dtype +class ScatterSub(PrimitiveWithInfer): + """ + Update the value of the input tensor through the sub operation. + + Using given values to update tensor value through the sub operation, along with the input indices. + This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value. + + Args: + use_locking (bool): Whether protect the assignment by a lock. Default: False. + + Inputs: + - **input_x** (Parameter) - The target parameter. + - **indices** (Tensor) - The index to do sub operation whose data type should be mindspore.int32. + - **updates** (Tensor) - The tensor doing the sub operation with `input_x`, + the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. + + Outputs: + Parameter, the updated `input_x`. + + Examples: + >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x") + >>> indices = Tensor(np.array([[0, 1]]), mindspore.int32) + >>> updates = Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mindspore.float32) + >>> scatter_sub = P.ScatterSub() + >>> output = scatter_sub(input_x, indices, updates) + [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]] + """ + + @prim_attr_register + def __init__(self, use_locking=False): + """Init ScatterSub""" + validator.check_value_type('use_locking', use_locking, (bool,), self.name) + + def infer_shape(self, x_shape, indices_shape, updates_shape): + _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) + return x_shape + + def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) + args = {'x': x_dtype, 'updates': updates_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + return x_dtype + + class SpaceToDepth(PrimitiveWithInfer): r""" Rearrange blocks of spatial data into depth. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 7ffd7c8f6e..cd829e362e 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -207,22 +207,35 @@ class HistogramSummaryNet(nn.Cell): class ScatterMax(nn.Cell): """ScatterMax net definition""" - def __init__(self): + def __init__(self, dtype=np.float32, use_locking=False): super(ScatterMax, self).__init__() - self.scatter_max = P.ScatterMax() - self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref") + self.scatter_max = P.ScatterMax(use_locking) + self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype)), name="ref") def construct(self, indices, updates): out = self.scatter_max(self.ref, indices, updates) return out +class ScatterMin(nn.Cell): + """ScatterMin net definition""" + + def __init__(self, dtype=np.float32, use_locking=False): + super(ScatterMin, self).__init__() + self.scatter_min = P.ScatterMin(use_locking) + self.ref = Parameter(Tensor(np.array([[-1.0, 2.0, 3.0], [-4.0, 1.0, 6.0]], dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_min(self.ref, indices, updates) + return out + + class ScatterAdd(nn.Cell): """ScatterAdd net definition""" - def __init__(self, ref_shape, dtype=np.float32): + def __init__(self, ref_shape, dtype=np.float32, use_locking=False): super(ScatterAdd, self).__init__() - self.scatter_add = P.ScatterAdd() + self.scatter_add = P.ScatterAdd(use_locking) self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") def construct(self, indices, updates): @@ -230,6 +243,19 @@ class ScatterAdd(nn.Cell): return out +class ScatterSub(nn.Cell): + """ScatterSub net definition""" + + def __init__(self, ref_shape, dtype=np.float32, use_locking=False): + super(ScatterSub, self).__init__() + self.scatter_sub = P.ScatterSub(use_locking) + self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref") + + def construct(self, indices, updates): + out = self.scatter_sub(self.ref, indices, updates) + return out + + class ApplyFtrlNet(nn.Cell): def __init__(self): super(ApplyFtrlNet, self).__init__() @@ -1741,11 +1767,61 @@ test_case_other_ops = [ Tensor(np.array([[0, 1], [1, 2]], np.int32)), Tensor(np.ones([2, 5], np.float32) * 99)), 'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}), - ('ScatterMax', { + ('ScatterMaxUseLocking', { + 'block': ScatterMax(use_locking=True), + 'desc_inputs': (Tensor(np.array([1, 0], np.int32)), + Tensor(np.array([[5.0, 5.0, 5.0], [4.0, 4.0, 4.0]], np.float32))), + 'skip': ['backward']}), + ('ScatterMax1d', { + 'block': ScatterMax(), + 'desc_inputs': (Tensor(np.array([1, 0], np.int32)), + Tensor(np.array([[5.0, 5.0, 5.0], [4.0, 4.0, 4.0]], np.float32))), + 'skip': ['backward']}), + ('ScatterMaxF32', { 'block': ScatterMax(), 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), Tensor(np.ones([2, 2, 3], np.float32) * 99)), 'skip': ['backward']}), + ('ScatterMaxF16', { + 'block': ScatterMax(np.float16), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.float16) * 99)), + 'skip': ['backward']}), + ('ScatterMaxI32', { + 'block': ScatterMax(np.int32), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.int32) * 99)), + 'skip': ['backward']}), + ('ScatterMinUseLocking', { + 'block': ScatterMin(use_locking=True), + 'desc_inputs': (Tensor(np.array([1, 0], np.int32)), + Tensor(np.ones([2, 3], np.float32))), + 'skip': ['backward']}), + ('ScatterMin1d', { + 'block': ScatterMin(), + 'desc_inputs': (Tensor(np.array([1, 0], np.int32)), + Tensor(np.ones([2, 3], np.float32))), + 'skip': ['backward']}), + ('ScatterMinF32', { + 'block': ScatterMin(), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.float32))), + 'skip': ['backward']}), + ('ScatterMinF16', { + 'block': ScatterMin(np.float16), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.float16))), + 'skip': ['backward']}), + ('ScatterMinI32', { + 'block': ScatterMin(np.int32), + 'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)), + Tensor(np.ones([2, 2, 3], np.int32))), + 'skip': ['backward']}), + ('ScatterAddUseLocking', { + 'block': ScatterAdd((6,), use_locking=True), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float32))), + 'skip': ['backward']}), ('ScatterAdd', { 'block': ScatterAdd((6,)), 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), @@ -1782,6 +1858,42 @@ test_case_other_ops = [ 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), Tensor(np.array([2, 3, 4], np.uint8))), 'skip': ['backward']}), + ('ScatterSubUseLocking', { + 'block': ScatterSub((6,), use_locking=True), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterSubScalar', { + 'block': ScatterSub((6,)), + 'desc_inputs': (Tensor(np.array([2], np.int32)), + Tensor(np.array([2.0], np.float32))), + 'skip': ['backward']}), + ('ScatterSub2d', { + 'block': ScatterSub((3, 4)), + 'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)), + Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]], + [[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))), + 'skip': ['backward']}), + ('ScatterSubF16', { + 'block': ScatterSub((6,), np.float16), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2.0, 3.0, 4.0], np.float16))), + 'skip': ['backward']}), + ('ScatterSubI32', { + 'block': ScatterSub((6,), np.int32), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int32))), + 'skip': ['backward']}), + ('ScatterSubI8', { + 'block': ScatterSub((6,), np.int8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([2, 3, 4], np.int8))), + 'skip': ['backward']}), + ('ScatterSubU8', { + 'block': ScatterSub((6,), np.uint8), + 'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)), + Tensor(np.array([1, 1, 0], np.uint8))), + 'skip': ['backward']}), ('SmoothL1Loss', { 'block': P.SmoothL1Loss(), 'desc_inputs': [[256, 4], [256, 4]],