!10083 support non tensor inputs in model predict

From: @zhangbuxue
Reviewed-by: @zhunaipan,@zh_qh
Signed-off-by: @zh_qh
pull/10083/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 734f50f36e

@ -635,7 +635,7 @@ def check_input_data(*data, data_class):
f' either a single'
f' or a list of {data_class.__name__},'
f' but got part data type is {str(type(item))}.')
if item.size() == 0:
if hasattr(item, "size") and item.size() == 0:
msg = "Please provide non-empty data."
logger.error(msg)
raise ValueError(msg)

@ -724,7 +724,7 @@ class Model:
Batch data should be put together in one tensor.
Args:
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
predict_data: The predict data, can be array, number, str, dict, list or tuple.
Returns:
Tensor, array(s) of predictions.
@ -735,7 +735,7 @@ class Model:
>>> result = model.predict(input_data)
"""
self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor)
check_input_data(*predict_data, data_class=(int, float, str, tuple, list, dict, Tensor))
_parallel_predict_check()
result = self._predict_network(*predict_data)

@ -214,8 +214,8 @@ def test_model_build_abnormal_string():
err = False
try:
model.predict('aaa')
except ValueError as e:
log.error("Find value error: %r ", e)
except TypeError as e:
log.error("Find type error: %r ", e)
err = True
finally:
assert err

Loading…
Cancel
Save