Add Ones and Zeros operators

pull/8510/head
l00591931 4 years ago
parent c5b5a6719c
commit 886ef520d7

@ -22,7 +22,7 @@ A collection of operators to build neural networks or to compute functions.
from .image_ops import (CropAndResize) from .image_ops import (CropAndResize)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye, Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, Fill, Ones, Zeros, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Meshgrid,
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin, SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,

@ -998,6 +998,93 @@ class Fill(PrimitiveWithInfer):
return out return out
class Ones(PrimitiveWithInfer):
"""
Creates a tensor filled with value ones.
Creates a tensor with shape described by the first argument and
fills it with value ones in type of the second argument.
Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs:
Tensor, has the same type and shape as input value.
Examples:
>>> ones = P.Ones()
>>> Ones((2, 2), mindspore.float32)
[[1.0, 1.0],
[1.0, 1.0]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Fill"""
def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.ones(dims['value'], x_nptype)
out = {
'value': Tensor(ret),
'shape': dims['value'],
'dtype': x_nptype,
}
return out
class Zeros(PrimitiveWithInfer):
"""
Creates a tensor filled with value zeros.
Creates a tensor with shape described by the first argument and
fills it with value zeros in type of the second argument.
Inputs:
- **shape** (tuple) - The specified shape of output tensor. Only constant value is allowed.
- **type** (mindspore.dtype) - The specified type of output tensor. Only constant value is allowed.
Outputs:
Tensor, has the same type and shape as input value.
Examples:
>>> zeros = P.Zeros()
>>> Zeros((2, 2), mindspore.float32)
[[0.0, 0.0],
[0.0, 0.0]]
"""
@prim_attr_register
def __init__(self):
"""Initialize Fill"""
def __infer__(self, dims, dtype):
validator.check_value_type("shape", dims['value'], [tuple], self.name)
for i, item in enumerate(dims['value']):
validator.check_positive_int(item, f'dims[{i}]', self.name)
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
mstype.uint8, mstype.uint32, mstype.uint64,
mstype.float16, mstype.float32, mstype.float64]
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
x_nptype = mstype.dtype_to_nptype(dtype['value'])
ret = np.zeros(dims['value'], x_nptype)
out = {
'value': Tensor(ret),
'shape': dims['value'],
'dtype': x_nptype,
}
return out
class OnesLike(PrimitiveWithInfer): class OnesLike(PrimitiveWithInfer):
""" """
Creates a new tensor. The values of all elements are 1. Creates a new tensor. The values of all elements are 1.

@ -52,6 +52,20 @@ def test_cast():
assert np.all(result.asnumpy() == expect) assert np.all(result.asnumpy() == expect)
def test_ones():
ones = P.Ones()
output = ones((2, 3), mstype.int32)
assert output.asnumpy().shape == (2, 3)
assert np.sum(output.asnumpy()) == 6
def test_zeros():
zeros = P.Zeros()
output = zeros((2, 3), mstype.int32)
assert output.asnumpy().shape == (2, 3)
assert np.sum(output.asnumpy()) == 0
@non_graph_engine @non_graph_engine
def test_reshape(): def test_reshape():
input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]])) input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))

Loading…
Cancel
Save