From f1e92d0ea727a07344bf2af7be34e2088416ed28 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Fri, 20 Nov 2020 15:56:04 +0800 Subject: [PATCH] Change Ones/zeros --- mindspore/ops/operations/__init__.py | 2 ++ mindspore/ops/operations/array_ops.py | 48 +++++++++++++++++---------- tests/ut/python/ops/test_array_ops.py | 14 ++++++++ 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 7d7f1a2777..22f4c77c2f 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -180,6 +180,8 @@ __all__ = [ 'Invert', 'TruncatedNormal', 'Fill', + 'Ones', + 'Zeros', 'OnesLike', 'ZerosLike', 'Select', diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a7c6b4ef50..bb7a591915 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1121,22 +1121,24 @@ class Fill(PrimitiveWithInfer): class Ones(PrimitiveWithInfer): - """ + r""" Creates a tensor filled with value ones. Creates a tensor with shape described by the first argument and fills it with value ones in type of the second argument. Inputs: - - **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed. + - **shape** (Union[tuple[int], int]) - The specified shape of output tensor. + Only constant positive int is allowed. - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. Outputs: - Tensor, has the same type and shape as input value. + Tensor, has the same type and shape as input shape value. Examples: + >>> from mindspore.ops import operations as P >>> ones = P.Ones() - >>> output = Ones((2, 2), mindspore.float32) + >>> output = ones((2, 2), mindspore.float32) >>> print(output) [[1.0, 1.0], [1.0, 1.0]] @@ -1147,40 +1149,46 @@ class Ones(PrimitiveWithInfer): """Initialize Fill""" def __infer__(self, dims, dtype): - validator.check_value_type("shape", dims['value'], [tuple], self.name) - for i, item in enumerate(dims['value']): - validator.check_positive_int(item, f'dims[{i}]', self.name) + if isinstance(dims['value'], int): + shape = (dims['value'],) + else: + shape = dims['value'] + validator.check_value_type("shape", shape, [tuple], self.name) + for i, item in enumerate(shape): + validator.check_non_negative_int(item, shape[i], self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) - ret = np.ones(dims['value'], x_nptype) + ret = np.ones(shape, x_nptype) out = { 'value': Tensor(ret), - 'shape': dims['value'], + 'shape': shape, 'dtype': x_nptype, } return out class Zeros(PrimitiveWithInfer): - """ + r""" Creates a tensor filled with value zeros. Creates a tensor with shape described by the first argument and fills it with value zeros in type of the second argument. Inputs: - - **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed. + - **shape** (Union[tuple[int], int]) - The specified shape of output tensor. + Only constant positive int is allowed. - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. Outputs: - Tensor, has the same type and shape as input value. + Tensor, has the same type and shape as input shape value. Examples: + >>> from mindspore.ops import operations as P >>> zeros = P.Zeros() - >>> output = Zeros((2, 2), mindspore.float32) + >>> output = zeros((2, 2), mindspore.float32) >>> print(output) [[0.0, 0.0], [0.0, 0.0]] @@ -1192,18 +1200,22 @@ class Zeros(PrimitiveWithInfer): """Initialize Fill""" def __infer__(self, dims, dtype): - validator.check_value_type("shape", dims['value'], [tuple], self.name) - for i, item in enumerate(dims['value']): - validator.check_positive_int(item, f'dims[{i}]', self.name) + if isinstance(dims['value'], int): + shape = (dims['value'],) + else: + shape = dims['value'] + validator.check_value_type("shape", shape, [tuple], self.name) + for i, item in enumerate(shape): + validator.check_non_negative_int(item, shape[i], self.name) valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8, mstype.uint32, mstype.uint64, mstype.float16, mstype.float32, mstype.float64] validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) x_nptype = mstype.dtype_to_nptype(dtype['value']) - ret = np.zeros(dims['value'], x_nptype) + ret = np.zeros(shape, x_nptype) out = { 'value': Tensor(ret), - 'shape': dims['value'], + 'shape': shape, 'dtype': x_nptype, } return out diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index febf4c6b99..3992508265 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -59,6 +59,13 @@ def test_ones(): assert np.sum(output.asnumpy()) == 6 +def test_ones_1(): + ones = P.Ones() + output = ones(2, mstype.int32) + assert output.asnumpy().shape == (2,) + assert np.sum(output.asnumpy()) == 2 + + def test_zeros(): zeros = P.Zeros() output = zeros((2, 3), mstype.int32) @@ -66,6 +73,13 @@ def test_zeros(): assert np.sum(output.asnumpy()) == 0 +def test_zeros_1(): + zeros = P.Zeros() + output = zeros(2, mstype.int32) + assert output.asnumpy().shape == (2,) + assert np.sum(output.asnumpy()) == 0 + + @non_graph_engine def test_reshape(): input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))