Change Ones/zeros

pull/8842/head
l00591931 4 years ago
parent f052ce8ba2
commit f1e92d0ea7

@ -180,6 +180,8 @@ __all__ = [
'Invert', 'Invert',
'TruncatedNormal', 'TruncatedNormal',
'Fill', 'Fill',
'Ones',
'Zeros',
'OnesLike', 'OnesLike',
'ZerosLike', 'ZerosLike',
'Select', 'Select',

@ -1121,22 +1121,24 @@ class Fill(PrimitiveWithInfer):
class Ones(PrimitiveWithInfer): class Ones(PrimitiveWithInfer):
""" r"""
Creates a tensor filled with value ones. Creates a tensor filled with value ones.
Creates a tensor with shape described by the first argument and Creates a tensor with shape described by the first argument and
fills it with value ones in type of the second argument. fills it with value ones in type of the second argument.
Inputs: Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed. - **shape** (Union[tuple[int], int]) - The specified shape of output tensor.
Only constant positive int is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs: Outputs:
Tensor, has the same type and shape as input value. Tensor, has the same type and shape as input shape value.
Examples: Examples:
>>> from mindspore.ops import operations as P
>>> ones = P.Ones() >>> ones = P.Ones()
>>> output = Ones((2, 2), mindspore.float32) >>> output = ones((2, 2), mindspore.float32)
>>> print(output) >>> print(output)
[[1.0, 1.0], [[1.0, 1.0],
[1.0, 1.0]] [1.0, 1.0]]
@ -1147,40 +1149,46 @@ class Ones(PrimitiveWithInfer):
"""Initialize Fill""" """Initialize Fill"""
def __infer__(self, dims, dtype): def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name) if isinstance(dims['value'], int):
for i, item in enumerate(dims['value']): shape = (dims['value'],)
validator.check_positive_int(item, f'dims[{i}]', self.name) else:
shape = dims['value']
validator.check_value_type("shape", shape, [tuple], self.name)
for i, item in enumerate(shape):
validator.check_non_negative_int(item, shape[i], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64, mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64] mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value']) x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.ones(dims['value'], x_nptype) ret = np.ones(shape, x_nptype)
out = { out = {
'value': Tensor(ret), 'value': Tensor(ret),
'shape': dims['value'], 'shape': shape,
'dtype': x_nptype, 'dtype': x_nptype,
} }
return out return out
class Zeros(PrimitiveWithInfer): class Zeros(PrimitiveWithInfer):
""" r"""
Creates a tensor filled with value zeros. Creates a tensor filled with value zeros.
Creates a tensor with shape described by the first argument and Creates a tensor with shape described by the first argument and
fills it with value zeros in type of the second argument. fills it with value zeros in type of the second argument.
Inputs: Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed. - **shape** (Union[tuple[int], int]) - The specified shape of output tensor.
Only constant positive int is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed. - **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs: Outputs:
Tensor, has the same type and shape as input value. Tensor, has the same type and shape as input shape value.
Examples: Examples:
>>> from mindspore.ops import operations as P
>>> zeros = P.Zeros() >>> zeros = P.Zeros()
>>> output = Zeros((2, 2), mindspore.float32) >>> output = zeros((2, 2), mindspore.float32)
>>> print(output) >>> print(output)
[[0.0, 0.0], [[0.0, 0.0],
[0.0, 0.0]] [0.0, 0.0]]
@ -1192,18 +1200,22 @@ class Zeros(PrimitiveWithInfer):
"""Initialize Fill""" """Initialize Fill"""
def __infer__(self, dims, dtype): def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name) if isinstance(dims['value'], int):
for i, item in enumerate(dims['value']): shape = (dims['value'],)
validator.check_positive_int(item, f'dims[{i}]', self.name) else:
shape = dims['value']
validator.check_value_type("shape", shape, [tuple], self.name)
for i, item in enumerate(shape):
validator.check_non_negative_int(item, shape[i], self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64, valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64, mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64] mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name) validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value']) x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.zeros(dims['value'], x_nptype) ret = np.zeros(shape, x_nptype)
out = { out = {
'value': Tensor(ret), 'value': Tensor(ret),
'shape': dims['value'], 'shape': shape,
'dtype': x_nptype, 'dtype': x_nptype,
} }
return out return out

@ -59,6 +59,13 @@ def test_ones():
assert np.sum(output.asnumpy()) == 6 assert np.sum(output.asnumpy()) == 6
def test_ones_1():
ones = P.Ones()
output = ones(2, mstype.int32)
assert output.asnumpy().shape == (2,)
assert np.sum(output.asnumpy()) == 2
def test_zeros(): def test_zeros():
zeros = P.Zeros() zeros = P.Zeros()
output = zeros((2, 3), mstype.int32) output = zeros((2, 3), mstype.int32)
@ -66,6 +73,13 @@ def test_zeros():
assert np.sum(output.asnumpy()) == 0 assert np.sum(output.asnumpy()) == 0
def test_zeros_1():
zeros = P.Zeros()
output = zeros(2, mstype.int32)
assert output.asnumpy().shape == (2,)
assert np.sum(output.asnumpy()) == 0
@non_graph_engine @non_graph_engine
def test_reshape(): def test_reshape():
input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])) input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))

Loading…
Cancel
Save