!14013 Enable zero dimension using attribute not parameter

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
pull/14013/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c1a802627b

@ -74,7 +74,7 @@ class Tensor(Tensor_):
>>> assert t3.dtype == ms.float32 >>> assert t3.dtype == ms.float32
""" """
def __init__(self, input_data=None, dtype=None, shape=None, init=None, check_zero_dims=True): def __init__(self, input_data=None, dtype=None, shape=None, init=None):
self.init_finished = False self.init_finished = False
# If input data is numpy number, convert it to np array # If input data is numpy number, convert it to np array
if isinstance(input_data, np_types): if isinstance(input_data, np_types):
@ -92,13 +92,12 @@ class Tensor(Tensor_):
if isinstance(shape, numbers.Number): if isinstance(shape, numbers.Number):
shape = (shape,) shape = (shape,)
if check_zero_dims: if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \
if input_data is not None and isinstance(input_data, (tuple, list, np.ndarray)) \ and np.array(input_data).ndim > 1 and np.array(input_data).size == 0:
and np.array(input_data).ndim > 1 and np.array(input_data).size == 0: raise ValueError("input_data can not contain zero dimension.")
raise ValueError("input_data can not contain zero dimension.") if shape is not None and not (hasattr(init, "__enable_zero_dim__") and init.__enable_zero_dim__):
if shape is not None: if 0 in shape:
if 0 in shape: raise ValueError("Shape can not contain zero value.")
raise ValueError("Shape can not contain zero value.")
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method. # If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None: if init is None:

@ -26,6 +26,7 @@ import numbers
import numpy as np import numpy as np
from mindspore import log as logger from mindspore import log as logger
from mindspore.common.initializer import Zero
from .._utils import get_concat_offset from .._utils import get_concat_offset
from ..operations.math_ops import _infer_shape_reduce from ..operations.math_ops import _infer_shape_reduce
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@ -38,6 +39,7 @@ from ...common.parameter import Parameter
from ...common.tensor import Tensor from ...common.tensor import Tensor
class _ScatterOp(PrimitiveWithInfer): class _ScatterOp(PrimitiveWithInfer):
""" """
Defines Scatter operators Defines Scatter operators
@ -3015,8 +3017,13 @@ class StridedSlice(PrimitiveWithInfer):
ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v) ret_shape = self._compute_slicing_shape(x['shape'], begin_v, end_v, strides_v)
value = None if all(ret_shape) else Tensor(np.array([]).reshape(ret_shape), x['dtype'].element_type(), if all(ret_shape):
check_zero_dims=False) value = None
else:
init_func = Zero()
init_func.__enable_zero_dim__ = True
value = Tensor(dtype=x['dtype'].element_type(), shape=ret_shape, init=init_func)
if "max_value" in x and "min_value" in x: if "max_value" in x and "min_value" in x:
validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name) validator.check_value_type("min_value", x["min_value"], [tuple, list], self.name)
validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name) validator.check_value_type("max_value", x["max_value"], [tuple, list], self.name)

Loading…
Cancel
Save