!12867 modify Tensor shape check

From: @Somnus2020
Reviewed-by: 
Signed-off-by:
pull/12867/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit fe0185a753

@ -42,8 +42,8 @@ class Tensor(Tensor_):
The argument is used to define the data type of the output tensor. If it is None, the data type of the
output tensor will be as same as the `input_data`. Default: None.
shape (Union[tuple, list, int]): A list of integers, a tuple of integers or an integer as the shape of
output. Default: None.
init (:class:`Initializer`): the information of init data.
output. If `input_data` is available, `shape` doesn't need to be set. Default: None.
init (Initializer): the information of init data.
'init' is used for delayed initialization in parallel mode. Usually, it is not recommended to
use 'init' interface to initialize parameters in other conditions. If 'init' interface is used
to initialize parameters, the `init_data` API need to be called to convert `Tensor` to the actual data.
@ -52,18 +52,26 @@ class Tensor(Tensor_):
Tensor, with the same shape as `input_data`.
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore.common.tensor import Tensor
>>> from mindspore.common.initializer import One
>>> # initialize a tensor with input data
>>> t1 = Tensor(np.zeros([1, 2, 3]), mindspore.float32)
>>> t1 = Tensor(np.zeros([1, 2, 3]), ms.float32)
>>> assert isinstance(t1, Tensor)
>>> assert t1.shape == (1, 2, 3)
>>> assert t1.dtype == mindspore.float32
>>> assert t1.dtype == ms.float32
>>>
>>> # initialize a tensor with a float scalar
>>> t2 = Tensor(0.1)
>>> assert isinstance(t2, Tensor)
>>> assert t2.dtype == mindspore.float64
>>> assert t2.dtype == ms.float64
...
>>> # initialize a tensor with init
>>> t3 = Tensor(shape = (1, 3), dtype=ms.float32, init=One())
>>> assert isinstance(t3, Tensor)
>>> assert t3.shape == (1, 3)
>>> assert t3.dtype == ms.float32
"""
def __init__(self, input_data=None, dtype=None, shape=None, init=None):
@ -71,8 +79,8 @@ class Tensor(Tensor_):
if isinstance(input_data, np_types):
input_data = np.array(input_data)
if input_data is not None and shape is not None and input_data.shape != shape:
raise ValueError("input_data.shape and shape should be same.")
if input_data is not None and shape is not None:
raise ValueError("If input_data is available, shape doesn't need to be set")
if init is not None and (shape is None or dtype is None):
raise ValueError("init, dtype and shape must have values at the same time.")

Loading…
Cancel
Save