!12915 numpy-native fix isscalar in graph mode && add typecheck for initial, repeats

From: @jachua
Reviewed-by: @guoqi1024,@liangchenghui
Signed-off-by: @guoqi1024
pull/12915/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 515b6f66e8

@ -1596,7 +1596,7 @@ def ix_(*args):
Boolean masks are not supported. Boolean masks are not supported.
Args: Args:
*args (Tensor): 1-D, each sequence should be of integer type. *args (Tensor): 1-D sequences.
Returns: Returns:
Tuple of Tensor, `N` arrays with `N` dimensions each, with `N` the Tuple of Tensor, `N` arrays with `N` dimensions each, with `N` the

@ -1898,15 +1898,17 @@ def repeat(a, repeats, axis=None):
[3 4]] [3 4]]
""" """
_check_input_tensor(a) _check_input_tensor(a)
if not isinstance(repeats, (tuple, list)):
repeats = (repeats,)
_check_element_int(repeats)
if axis is None: if axis is None:
a = ravel(a) a = ravel(a)
axis = 0 axis = 0
ndim = F.rank(a) ndim = F.rank(a)
_check_axis_in_range(axis, ndim) _check_axis_in_range(axis, ndim)
axis = axis + ndim if axis < 0 else axis axis = axis + ndim if axis < 0 else axis
if isinstance(repeats, (tuple, list)) and len(repeats) == 1: if len(repeats) == 1:
repeats = repeats[0] repeats = repeats[0]
if isinstance(repeats, int):
if repeats == 0: if repeats == 0:
return _empty(F.dtype(a), (0,)) return _empty(F.dtype(a), (0,))
return C.repeat_elements(a, repeats, axis) return C.repeat_elements(a, repeats, axis)

@ -17,7 +17,9 @@
from .math_ops import _apply_tensor_op from .math_ops import _apply_tensor_op
from ..ops import functional as F from ..ops import functional as F
from ..ops.primitive import constexpr
from ..common import dtype as mstype from ..common import dtype as mstype
from ..common import Tensor
from .._c_expression import typing from .._c_expression import typing
from .array_creations import zeros, ones from .array_creations import zeros, ones
@ -530,6 +532,13 @@ def isneginf(x):
return _is_sign_inf(x, F.tensor_lt) return _is_sign_inf(x, F.tensor_lt)
@constexpr
def _isscalar(x):
"""Returns True if x is a scalar type"""
return isinstance(x, (typing.Number, typing.Int, typing.UInt, typing.Float,
typing.Bool, typing.String))
def isscalar(element): def isscalar(element):
""" """
Returns True if the type of element is a scalar type. Returns True if the type of element is a scalar type.
@ -565,5 +574,5 @@ def isscalar(element):
>>> print(output) >>> print(output)
True True
""" """
return isinstance(F.typeof(element), (typing.Number, typing.Int, typing.UInt, obj_type = F.typeof(element)
typing.Float, typing.Bool, typing.String)) return not isinstance(obj_type, Tensor) and _isscalar(obj_type)

@ -2237,6 +2237,10 @@ def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where
ndim = F.rank(a) ndim = F.rank(a)
dtype = F.dtype(a) dtype = F.dtype(a)
axes = _check_axis_valid(axis, ndim) axes = _check_axis_valid(axis, ndim)
if initial is not None:
if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
not isinstance(initial, (int, float, bool, Tensor))):
_raise_type_error('initial should be scalar')
if _is_shape_empty(shape): if _is_shape_empty(shape):
if not axes: if not axes:

Loading…
Cancel
Save