From eb4e1a0d2b80e72ed4431956f57ea4fd5793df68 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Sat, 30 May 2020 10:37:14 +0800 Subject: [PATCH] ScatterAdd ScatterMax indices limited to int32 --- mindspore/ops/operations/array_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index eedbdb6500..19828d3871 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2222,7 +2222,7 @@ class ScatterMax(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target parameter. - - **indices** (Tensor) - The index to do max operation whose data type should be int. + - **indices** (Tensor) - The index to do max operation whose data type should be mindspore.int32. - **updates** (Tensor) - The tensor doing the maximum operation with `input_x`, the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. @@ -2249,7 +2249,7 @@ class ScatterMax(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) args = {"x": x_dtype, "updates": updates_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -2266,7 +2266,7 @@ class ScatterAdd(PrimitiveWithInfer): Inputs: - **input_x** (Parameter) - The target parameter. - - **indices** (Tensor) - The index to do add operation whose data type should be int. + - **indices** (Tensor) - The index to do add operation whose data type should be mindspore.int32. - **updates** (Tensor) - The tensor doing the add operation with `input_x`, the data type is same as `input_x`, the shape is `indices_shape + x_shape[1:]`. @@ -2292,7 +2292,7 @@ class ScatterAdd(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): - validator.check_tensor_type_same({'indices': indices_dtype}, mstype.int_type, self.name) + validator.check_tensor_type_same({'indices': indices_dtype}, (mstype.int32,), self.name) args = {'x': x_dtype, 'updates': updates_dtype} validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype