From 2ce5de2f47a8acbf3ddf48a77b31a0bf6c4cf1ce Mon Sep 17 00:00:00 2001 From: l00591931 Date: Wed, 7 Apr 2021 17:30:36 +0800 Subject: [PATCH] Change Tensor zero dimension check to make it faster --- mindspore/common/tensor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 73060e5f93..4de33b82b6 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -92,9 +92,13 @@ class Tensor(Tensor_): if isinstance(shape, numbers.Number): shape = (shape,) - if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \ - and np.array(input_data).ndim > 1 and np.array(input_data).size == 0: - raise ValueError("input_data can not contain zero dimension.") + if input_data is not None: + if isinstance(input_data, np.ndarray) and input_data.ndim > 1 and input_data.size == 0: + raise ValueError("input_data can not contain zero dimension.") + if isinstance(input_data, (tuple, list)) and np.array(input_data).ndim > 1 \ + and np.array(input_data).size == 0: + raise ValueError("input_data can not contain zero dimension.") + if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__): if 0 in shape: raise ValueError("Shape can not contain zero value.")