add dynamic shape for ScatterAdd/Update

pull/8236/head
TFbunny 5 years ago
parent 4b4ca1a188
commit 0bdf6c51a7

@ -69,6 +69,37 @@ class _ScatterOp(PrimitiveWithInfer):
return x_dtype
class _ScatterOp_Dynamic(PrimitiveWithCheck):
"""
Defines Scatter operators with dynamic shape
"""
__mindspore_signature__ = (
sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
sig.make_sig('updates', dtype=sig.sig_dtype.T)
)
def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name):
if indices_shape != [-1] and updates_shape and updates_shape != indices_shape + x_shape[1:]:
raise ValueError(f"For '{prim_name}', "
f"updates_shape = indices_shape + x_shape[1:], but got x_shape: {x_shape}, "
f"indices_shape: {indices_shape}, updates_shape: {updates_shape}.")
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize _ScatterOp_Dynamic"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def check_shape(self, x_shape, indices_shape, updates_shape):
self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name)
def check_dtype(self, x_dtype, indices_dtype, updates_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "updates": updates_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
class _ScatterNdOp(_ScatterOp):
"""
Defines _ScatterNd operators
@ -2723,7 +2754,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_dtype
class ScatterUpdate(_ScatterOp):
class ScatterUpdate(_ScatterOp_Dynamic):
"""
Updates tensor value by using input indices and value.
@ -2757,20 +2788,12 @@ class ScatterUpdate(_ScatterOp):
[[2.0, 1.2, 1.0],
[3.0, 1.2, 1.0]]
"""
@prim_attr_register
def __init__(self, use_locking=True):
"""Initialize ScatterUpdate"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
class ScatterNdUpdate(_ScatterNdOp):
"""
Updates tensor value by using input indices and value.
@ -2891,7 +2914,7 @@ class ScatterMin(_ScatterOp):
"""
class ScatterAdd(_ScatterOp):
class ScatterAdd(_ScatterOp_Dynamic):
"""
Updates the value of the input tensor through the add operation.
@ -2923,6 +2946,11 @@ class ScatterAdd(_ScatterOp):
>>> output = scatter_add(input_x, indices, updates)
[[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
"""
@prim_attr_register
def __init__(self, use_locking=False):
"""Initialize ScatterAdd"""
validator.check_value_type('use_locking', use_locking, [bool], self.name)
self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y'])
class ScatterSub(_ScatterOp):

Loading…
Cancel
Save