!2699 add ScatterMax ScatterMin ScatterSub vm

Merge pull request !2699 from zhaozhenlong/op/scatter-max-vm
pull/2699/MERGE
mindspore-ci-bot committed by Gitee
commit 3a89f93cc5

@@ -273,3 +273,6 @@ from .matrix_diag_part import _matrix_diag_part_tbe
from .matrix_set_diag import _matrix_set_diag_tbe
from .lrn import _lrn_tbe
from .lrn_grad import _lrn_grad_tbe
from .scatter_max import _scatter_max_tbe
from .scatter_min import _scatter_min_tbe
from .scatter_sub import _scatter_sub_tbe

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterMax op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_max_op_info = TBERegOp("ScatterMax") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_max.so") \
.compute_cost(10) \
.kernel_name("scatter_max") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_max_op_info)
def _scatter_max_tbe():
"""ScatterMax TBE register"""
return

@@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterMin op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_min_op_info = TBERegOp("ScatterMin") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_min.so") \
.compute_cost(10) \
.kernel_name("scatter_min") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(scatter_min_op_info)
def _scatter_min_tbe():
"""ScatterMin TBE register"""
return

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScatterSub op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
scatter_sub_op_info = TBERegOp("ScatterSub") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("scatter_sub.so") \
.compute_cost(10) \
.kernel_name("scatter_sub") \
.partial_flag(True) \
.attr("use_locking", "optional", "bool", "all") \
.input(0, "var", False, "required", "all") \
.input(1, "indices", False, "required", "all") \
.input(2, "updates", False, "required", "all") \
.output(0, "var", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
.get_op_info()
@op_info_register(scatter_sub_op_info)
def _scatter_sub_tbe():
"""ScatterSub TBE register"""
return

@@ -25,7 +25,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
-                        SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
+                        SameTypeShape, ScatterAdd, ScatterSub, ScatterMax, ScatterMin, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split,
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
@@ -214,8 +214,10 @@ __all__ = [
'BoundingBoxDecode',
'L2Normalize',
'ScatterAdd',
'ScatterSub',
'ScatterNd',
'ScatterMax',
'ScatterMin',
'ResizeNearestNeighbor',
'HistogramFixedWidth',
'Pad',

@@ -2272,6 +2272,51 @@ class ScatterMax(PrimitiveWithInfer):
return x_dtype

class ScatterMin(PrimitiveWithInfer):
    """
    Updates the value of the input tensor through the min operation.

    Using given values to update the tensor value through the min operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    Args:
        use_locking (bool): Whether to protect the assignment with a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do the min operation; its data type should be mindspore.int32.
        - **updates** (Tensor) - The tensor that performs the min operation with `input_x`;
          its data type is the same as `input_x` and its shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Parameter, the updated `input_x`.

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mindspore.float32), name="input_x")
        >>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
        >>> update = Tensor(np.ones([2, 2, 3]), mindspore.float32)
        >>> scatter_min = P.ScatterMin()
        >>> output = scatter_min(input_x, indices, update)
        [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterMin"""
        self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
        args = {"x": x_dtype, "updates": updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
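
Since duplicate indices are applied unbuffered, the docstring example above can be sanity-checked with plain NumPy (a quick sketch, not the primitive itself):

import numpy as np

input_x = np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]], np.float32)
indices = np.array([[0, 0], [1, 1]], np.int32)
update = np.ones([2, 2, 3], np.float32)

np.minimum.at(input_x, indices, update)
print(input_x)  # -> [[0. 1. 1.], [0. 0. 0.]]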
class ScatterAdd(PrimitiveWithInfer):
"""
Update the value of the input tensor through the add operation.
@@ -2316,6 +2361,50 @@ class ScatterAdd(PrimitiveWithInfer):
return x_dtype

class ScatterSub(PrimitiveWithInfer):
    """
    Updates the value of the input tensor through the sub operation.

    Using given values to update the tensor value through the sub operation, along with the input indices.
    This operation outputs the `input_x` after the update is done, which makes it convenient to use the updated value.

    Args:
        use_locking (bool): Whether to protect the assignment with a lock. Default: False.

    Inputs:
        - **input_x** (Parameter) - The target parameter.
        - **indices** (Tensor) - The index to do the sub operation; its data type should be mindspore.int32.
        - **updates** (Tensor) - The tensor that performs the sub operation with `input_x`;
          its data type is the same as `input_x` and its shape is `indices_shape + x_shape[1:]`.

    Outputs:
        Parameter, the updated `input_x`.

    Examples:
        >>> input_x = Parameter(Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mindspore.float32), name="x")
        >>> indices = Tensor(np.array([[0, 1]]), mindspore.int32)
        >>> updates = Tensor(np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]), mindspore.float32)
        >>> scatter_sub = P.ScatterSub()
        >>> output = scatter_sub(input_x, indices, updates)
        [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
    """

    @prim_attr_register
    def __init__(self, use_locking=False):
        """Init ScatterSub"""
        validator.check_value_type('use_locking', use_locking, (bool,), self.name)

    def infer_shape(self, x_shape, indices_shape, updates_shape):
        _check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
        return x_shape

    def infer_dtype(self, x_dtype, indices_dtype, updates_dtype):
        validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name)
        args = {'x': x_dtype, 'updates': updates_dtype}
        validator.check_tensor_type_same(args, mstype.number_type, self.name)
        return x_dtype
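
The shape rule enforced by _check_scatter_shape is updates.shape == indices.shape + x.shape[1:]; the docstring example satisfies it, as this small NumPy sketch illustrates (an illustration only, not the primitive):

import numpy as np

x = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], np.float32)
indices = np.array([[0, 1]], np.int32)                                # shape (1, 2)
updates = np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]], np.float32)  # shape (1, 2, 3)

assert updates.shape == indices.shape + x.shape[1:]
np.subtract.at(x, indices, updates)
print(x)  # -> [[-1. -1. -1.], [-1. -1. -1.]]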
class SpaceToDepth(PrimitiveWithInfer):
r"""
Rearrange blocks of spatial data into depth.

@@ -207,22 +207,35 @@ class HistogramSummaryNet(nn.Cell):
class ScatterMax(nn.Cell):
"""ScatterMax net definition"""
-    def __init__(self):
+    def __init__(self, dtype=np.float32, use_locking=False):
         super(ScatterMax, self).__init__()
-        self.scatter_max = P.ScatterMax()
-        self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref")
+        self.scatter_max = P.ScatterMax(use_locking)
+        self.ref = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_max(self.ref, indices, updates)
return out
class ScatterMin(nn.Cell):
"""ScatterMin net definition"""
def __init__(self, dtype=np.float32, use_locking=False):
super(ScatterMin, self).__init__()
self.scatter_min = P.ScatterMin(use_locking)
self.ref = Parameter(Tensor(np.array([[-1.0, 2.0, 3.0], [-4.0, 1.0, 6.0]], dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_min(self.ref, indices, updates)
return out
class ScatterAdd(nn.Cell):
"""ScatterAdd net definition"""
-    def __init__(self, ref_shape, dtype=np.float32):
+    def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
         super(ScatterAdd, self).__init__()
-        self.scatter_add = P.ScatterAdd()
+        self.scatter_add = P.ScatterAdd(use_locking)
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
@@ -230,6 +243,19 @@ class ScatterAdd(nn.Cell):
return out
class ScatterSub(nn.Cell):
"""ScatterSub net definition"""
def __init__(self, ref_shape, dtype=np.float32, use_locking=False):
super(ScatterSub, self).__init__()
self.scatter_sub = P.ScatterSub(use_locking)
self.ref = Parameter(Tensor(np.ones(ref_shape, dtype)), name="ref")
def construct(self, indices, updates):
out = self.scatter_sub(self.ref, indices, updates)
return out
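
These small Cells are driven by the test cases below; a standalone invocation would look roughly like this (a sketch, assuming a configured MindSpore context and the imports already present in this test file):

import numpy as np
from mindspore import Tensor

net = ScatterSub((6,), np.float32)
indices = Tensor(np.array([2, 0, 5], np.int32))
updates = Tensor(np.array([2.0, 3.0, 4.0], np.float32))
out = net(indices, updates)  # ref starts as ones; entries 2, 0, 5 become -1.0, -2.0, -3.0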
class ApplyFtrlNet(nn.Cell):
def __init__(self):
super(ApplyFtrlNet, self).__init__()
@@ -1741,11 +1767,61 @@ test_case_other_ops = [
Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.ones([2, 5], np.float32) * 99)),
'desc_bprop': [([3, 4, 5], {'dtype': np.float32})]}),
-    ('ScatterMax', {
+    ('ScatterMaxUseLocking', {
'block': ScatterMax(use_locking=True),
'desc_inputs': (Tensor(np.array([1, 0], np.int32)),
Tensor(np.array([[5.0, 5.0, 5.0], [4.0, 4.0, 4.0]], np.float32))),
'skip': ['backward']}),
('ScatterMax1d', {
'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([1, 0], np.int32)),
Tensor(np.array([[5.0, 5.0, 5.0], [4.0, 4.0, 4.0]], np.float32))),
'skip': ['backward']}),
('ScatterMaxF32', {
'block': ScatterMax(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float32) * 99)),
'skip': ['backward']}),
('ScatterMaxF16', {
'block': ScatterMax(np.float16),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float16) * 99)),
'skip': ['backward']}),
('ScatterMaxI32', {
'block': ScatterMax(np.int32),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.int32) * 99)),
'skip': ['backward']}),
('ScatterMinUseLocking', {
'block': ScatterMin(use_locking=True),
'desc_inputs': (Tensor(np.array([1, 0], np.int32)),
Tensor(np.ones([2, 3], np.float32))),
'skip': ['backward']}),
('ScatterMin1d', {
'block': ScatterMin(),
'desc_inputs': (Tensor(np.array([1, 0], np.int32)),
Tensor(np.ones([2, 3], np.float32))),
'skip': ['backward']}),
('ScatterMinF32', {
'block': ScatterMin(),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float32))),
'skip': ['backward']}),
('ScatterMinF16', {
'block': ScatterMin(np.float16),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.float16))),
'skip': ['backward']}),
('ScatterMinI32', {
'block': ScatterMin(np.int32),
'desc_inputs': (Tensor(np.array([[0, 0], [1, 1]], np.int32)),
Tensor(np.ones([2, 2, 3], np.int32))),
'skip': ['backward']}),
('ScatterAddUseLocking', {
'block': ScatterAdd((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float32))),
'skip': ['backward']}),
('ScatterAdd', {
'block': ScatterAdd((6,)),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
@@ -1782,6 +1858,42 @@ test_case_other_ops = [
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.uint8))),
'skip': ['backward']}),
('ScatterSubUseLocking', {
'block': ScatterSub((6,), use_locking=True),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterSubScalar', {
'block': ScatterSub((6,)),
'desc_inputs': (Tensor(np.array([2], np.int32)),
Tensor(np.array([2.0], np.float32))),
'skip': ['backward']}),
('ScatterSub2d', {
'block': ScatterSub((3, 4)),
'desc_inputs': (Tensor(np.array([[0, 1], [1, 2]], np.int32)),
Tensor(np.array([[[1, 1, 1, 1], [2, 2, 2, 2]],
[[3, 3, 3, 3], [4, 4, 4, 4]]], np.float32))),
'skip': ['backward']}),
('ScatterSubF16', {
'block': ScatterSub((6,), np.float16),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2.0, 3.0, 4.0], np.float16))),
'skip': ['backward']}),
('ScatterSubI32', {
'block': ScatterSub((6,), np.int32),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int32))),
'skip': ['backward']}),
('ScatterSubI8', {
'block': ScatterSub((6,), np.int8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([2, 3, 4], np.int8))),
'skip': ['backward']}),
('ScatterSubU8', {
'block': ScatterSub((6,), np.uint8),
'desc_inputs': (Tensor(np.array([2, 0, 5], np.int32)),
Tensor(np.array([1, 1, 0], np.uint8))),
'skip': ['backward']}),
('SmoothL1Loss', {
'block': P.SmoothL1Loss(),
'desc_inputs': [[256, 4], [256, 4]],
