optimize inputs check for predict of model and init data check for Tensor

pull/10179/head
buxue 5 years ago
parent f2b25d4139
commit a6cf444864

@ -623,22 +623,40 @@ def _expand_tuple(n_dimensions):
return convert
def _check_data_type_valid(data, valid_type):
"""Check data type valid."""
if valid_type is None:
return data is None
if isinstance(data, valid_type):
if hasattr(data, 'size') and data.size == 0:
msg = "Please provide non-empty data."
logger.error(msg)
raise ValueError(msg)
return True
return False
def check_input_data(*data, data_class):
"""Input data check."""
for item in data:
if isinstance(item, (list, tuple)):
for v in item:
check_input_data(v, data_class=data_class)
elif isinstance(item, dict):
for v in item.values():
check_input_data(v, data_class=data_class)
else:
if not isinstance(item, data_class):
raise ValueError(f'Please provide as model inputs'
f' either a single'
f' or a list of {data_class.__name__},'
f' but got part data type is {str(type(item))}.')
if hasattr(item, "size") and item.size == 0:
msg = "Please provide non-empty data."
logger.error(msg)
raise ValueError(msg)
if isinstance(data_class, (tuple, list)):
ret = True in tuple(_check_data_type_valid(item, data_type) for data_type in data_class)
else:
ret = _check_data_type_valid(item, data_class)
if not ret:
data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) \
if isinstance(data_class, (tuple, list)) else \
(data_class if data_class is None else data_class.__name__)
raise ValueError(f'Please provide as model inputs either a single or '
f'a tuple or a list or a dict of {data_class_str}, '
f'but got part data type is {item if item is None else type(item).__name__}.')
def check_output_data(data):

@ -235,27 +235,32 @@ def ms_function(fn=None, obj=None, input_signature=None):
Examples:
>>> from mindspore.ops import functional as F
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> # create a callable MindSpore graph by calling ms_function
>>> def tensor_add(x, y):
... z = x + y
... return z
...
>>> tensor_add_graph = ms_function(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
...
>>> # create a callable MindSpore graph through decorator @ms_function
>>> @ms_function
... def tensor_add_with_dec(x, y):
... z = x + y
... return z
...
>>> out = tensor_add_with_dec(x, y)
...
>>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
>>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)),
... MetaTensor(mindspore.float32, (1, 1, 3, 3))))
... def tensor_add_with_sig(x, y):
... z = x + y
... return z
...
>>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
>>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
...
>>> tensor_add_graph = ms_function(fn=tensor_add)
>>> out = tensor_add_graph(x, y)
>>> out = tensor_add_with_dec(x, y)
>>> out = tensor_add_with_sig(x, y)
"""

@ -71,8 +71,8 @@ class Tensor(Tensor_):
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is a numpy array whose value is {input_data} and "
f"data type is {input_data.dtype} that is not supported to initialize a Tensor.")
raise TypeError(f"For Tensor, the input_data is a numpy array, "
f"but it's data type is not in supported list: {list(i.__name__ for i in valid_dtypes)}.")
if isinstance(input_data, (tuple, list)):
if np.array(input_data).dtype not in valid_dtypes:
raise TypeError(f"For Tensor, the input_data is {input_data} that contain unsupported element.")

@ -61,7 +61,7 @@ def repeat_elements(x, rep, axis=0):
Args:
x (Tensor): The tensor to repeat values for. Must be of type: float16,
float32, int8, uint8, int16, int32, or int64.
float32, int8, uint8, int16, int32, or int64.
rep (int): The number of times to repeat, must be positive, required.
axis (int): The axis along which to repeat, default 0.

@ -4110,7 +4110,7 @@ class BroadcastTo(PrimitiveWithInfer):
Args:
shape (tuple): The target shape to broadcast. Can be fully specified, or have '-1's in one position
where it will be substituted by the input tensor's shape in that position, see example.
where it will be substituted by the input tensor's shape in that position, see example.
Inputs:
- **input_x** (Tensor) - The input tensor.

@ -368,7 +368,7 @@ class MakeRefKey(Primitive):
>>> class Net(nn.Cell):
... def __init__(self):
... super(Net, self).__init__()
... self.y = Parameter(Tensor(np.ones([6, 8, 10]), mstype.int32), name="y")
... self.y = Parameter(Tensor(np.ones([2, 3]), mstype.int32), name="y")
... self.make_ref_key = ops.MakeRefKey("y")
...
... def construct(self, x):
@ -376,10 +376,12 @@ class MakeRefKey(Primitive):
... ref = ops.make_ref(key, x, self.y)
... return ref * x
...
>>> x = Tensor(np.ones([3, 4, 5]), mstype.int32)
>>> x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]), mindspore.int32)
>>> net = Net()
>>> output = net(x)
>>> print(output)
[[ 1 4 9]
[16 25 36]]
"""
@prim_attr_register

@ -480,7 +480,7 @@ def constexpr(fn=None, get_instance=True, name=None):
... return len(x)
>>> assert tuple_len(a) == 2
...
>>> # make a operator class to calculate tuple len
>>> # make an operator class to calculate tuple len
>>> @constexpr(get_instance=False, name="TupleLen")
>>> def tuple_len_class(x):
... return len(x)

@ -727,7 +727,8 @@ class Model:
Batch data should be put together in one tensor.
Args:
predict_data: The predict data, can be array, number, str, dict, list or tuple.
predict_data: The predict data, can be bool, int, float, str, None, tensor,
or tuple, list and dict that store these types.
Returns:
Tensor, array(s) of predictions.
@ -738,7 +739,7 @@ class Model:
>>> result = model.predict(input_data)
"""
self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=(int, float, str, tuple, list, dict, Tensor))
check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))
_parallel_predict_check()
result = self._predict_network(*predict_data)

Loading…
Cancel
Save