From: @Somnus2020
Reviewed-by: 
Signed-off-by:
pull/11177/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit fd6dc1b060

@ -17,13 +17,13 @@ from . import dtype
from .api import ms_function
from .dtype import *
from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
from .tensor import Tensor, RowTensor, SparseTensor
from .seed import set_seed, get_seed
__all__ = dtype.__all__
__all__.extend([
"MetaTensor", "Tensor", "RowTensor", "SparseTensor", # tensor
"Tensor", "RowTensor", "SparseTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype",

@ -100,7 +100,6 @@ class Parameter(Tensor_):
... def construct(self, x):
... out = self.matmul(self.weight, x)
... return out
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>> net = Net()
>>> x = Tensor(np.ones((2,1)))
>>> print(net(x))
@ -113,15 +112,14 @@ class Parameter(Tensor_):
__base_type__ = {}
def __new__(cls, default_input, *args, **kwargs):
init_data_flag = bool(isinstance(default_input, Tensor) and default_input.has_init)
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
new_type = Parameter._get_base_class(input_class)
obj = input_class.__new__(new_type)
input_class.__init__(obj, *class_init_args)
# it's better to make the Initializer a kind of tensor.
obj.init_mode = None
obj.is_default_input_init = False
if isinstance(default_input, Tensor) and default_input.has_init:
obj.is_default_input_init = True
obj.is_default_input_init = init_data_flag
if obj.has_init:
obj.init_mode = default_input
return obj

@ -19,11 +19,10 @@ from mindspore import log as logger
from mindspore.communication.management import get_rank, get_group_size
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
from .._c_expression import MetaTensor
from .._c_expression import Tensor as Tensor_
from .._checkparam import Validator as validator
__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
__all__ = ['Tensor', 'RowTensor', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
@ -41,6 +40,9 @@ class Tensor(Tensor_):
dtype (:class:`mindspore.dtype`): Input data should be None, bool or numeric type defined in `mindspore.dtype`.
The argument is used to define the data type of the output tensor. If it is None, the data type of the
output tensor will be as same as the `input_data`. Default: None.
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
init (class:'Initializer'): the information of init data.
Outputs:
Tensor, with the same shape as `input_data`.
@ -65,6 +67,12 @@ class Tensor(Tensor_):
if isinstance(input_data, np_types):
input_data = np.array(input_data)
if input_data is not None and shape is not None and input_data.shape != shape:
raise ValueError("input_data.shape and shape should be same.")
if init is not None and (shape is None or dtype is None):
raise ValueError("init, dtype and shape must have values at the same time.")
if ((input_data is not None and init is None) or (input_data is None and init is not None)) is False:
raise TypeError("input_data and init can not be None at the same time.")
@ -306,10 +314,12 @@ class Tensor(Tensor_):
def asnumpy(self):
"""Convert tensor to numpy array."""
self.init_check()
return Tensor_.asnumpy(self)
def _flush_from_cache(self):
"""Flush cache data to host if tensor is cache enable."""
self.init_check()
Tensor_._flush_from_cache(self)
def all(self, axis=(), keep_dims=False):
@ -327,6 +337,7 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
"""
self.init_check()
if axis is None:
axis = ()
return tensor_operator_registry.get('all')(keep_dims)(self, axis)
@ -346,6 +357,7 @@ class Tensor(Tensor_):
Tensor, has the same data type as x.
"""
self.init_check()
if axis is None:
axis = ()
return tensor_operator_registry.get('any')(keep_dims)(self, axis)
@ -360,6 +372,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same dimension as the input shape.
"""
self.init_check()
if not shape:
raise ValueError("The shape variable should not be empty")
if isinstance(shape[0], tuple):
@ -379,6 +392,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same dimension as input tensor.
"""
self.init_check()
return tensor_operator_registry.get('broadcast_to')(x.shape)(self)
def abs(self):
@ -388,6 +402,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same data type as x.
"""
self.init_check()
return tensor_operator_registry.get('abs')()(self)
def mean(self, axis=(), keep_dims=False):
@ -404,6 +419,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same data type as x.
"""
self.init_check()
if axis is None:
axis = ()
return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
@ -429,6 +445,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same dimension as input tensor, with axes suitably permuted.
"""
self.init_check()
perm = validator.check_transpose_axis(axes, self.ndim)
return tensor_operator_registry.get('transpose')()(self, perm)
@ -446,6 +463,7 @@ class Tensor(Tensor_):
reshaped_tensor(Tensor): This will be a new view object if possible;
otherwise, it will be a copy.
"""
self.init_check()
new_shape = validator.check_reshape_shp(shape)
return tensor_operator_registry.get('reshape')()(self, new_shape)
@ -457,6 +475,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same data type as x.
"""
self.init_check()
reshape_op = tensor_operator_registry.get('reshape')()
return reshape_op(self, (-1,))
@ -472,6 +491,7 @@ class Tensor(Tensor_):
Returns:
Tensor, has the same data type as x.
"""
self.init_check()
reshape_op = tensor_operator_registry.get('reshape')()
trans_op = tensor_operator_registry.get('transpose')()
@ -493,6 +513,7 @@ class Tensor(Tensor_):
Returns:
Transposed tensor, has the same data type as the original tensor x.
"""
self.init_check()
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
if axis1 == axis2:
@ -521,6 +542,7 @@ class Tensor(Tensor_):
Returns:
Tensor, with all or a subset of the dimensions of length 1 removed.
"""
self.init_check()
if axis is None:
return tensor_operator_registry.get('squeeze')(self)
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
@ -542,12 +564,18 @@ class Tensor(Tensor_):
Returns:
Tensor, with the designated dtype.
"""
self.init_check()
dtype = validator.check_astype_dtype(dtype)
if not copy and dtype == self.dtype:
return self
return tensor_operator_registry.get('cast')(self, dtype)
def init_check(self):
if self.has_init:
self.init_data()
return self
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
"""
Get the tensor format data of this Tensor.
@ -601,7 +629,9 @@ class Tensor(Tensor_):
rank = get_rank(opt_shard_group)
size = get_group_size(opt_shard_group)
data = np.split(data, size)[rank]
return Tensor(data, dtype=self.dtype)
self.init = None
self.assign_value(Tensor(data, dtype=self.dtype))
return self
def to_tensor(self, slice_index=None, shape=None, opt_shard_group=None):

@ -16,6 +16,7 @@
import numpy as np
import mindspore as ms
import mindspore.common.initializer as init
from mindspore.common.api import _executor
from mindspore.nn import Cell
from mindspore.ops import operations as P
@ -99,6 +100,12 @@ def test_asnumpy():
assert a.asnumpy().all() == npd.all()
def test_initializer_asnumpy():
npd = np.ones((2, 3))
a = init.initializer('one', [2, 3], ms.int32)
assert a.asnumpy().all() == npd.all()
def test_print():
a = ms.Tensor(np.ones((2, 3)))
a.set_dtype(ms.int32)

@ -13,7 +13,9 @@
# limitations under the License.
# ============================================================================
""" test expand_as"""
import mindspore as ms
import mindspore.nn as nn
import mindspore.common.initializer as init
from mindspore import Tensor
from mindspore import context
@ -34,6 +36,20 @@ def test_expand_as():
net()
def test_initializer_expand_as():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.t1 = init.initializer('one', [1, 3], ms.float32)
self.t2 = init.initializer('one', [2, 3], ms.float32)
def construct(self):
return self.t1.expand_as(self.t2)
net = Net()
net()
def test_expand_as_parameter():
class Net(nn.Cell):
def __init__(self):

@ -15,7 +15,9 @@
""" test view"""
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.common.initializer as init
from mindspore import Tensor
from mindspore import context
@ -35,6 +37,19 @@ def test_view():
net()
def test_view_initializer():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = init.initializer('normal', [2, 3], ms.float32)
def construct(self):
return self.value.view(-1)
net = Net()
net()
def test_view_1():
class Net(nn.Cell):
def __init__(self):

Loading…
Cancel
Save