diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 34ac0521c2..4e4f4b6f41 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -22,15 +22,14 @@ from ..ops.primitive import constexpr from ..nn.layer.basic import tril as nn_tril from ..nn.layer.basic import triu as nn_triu from .._c_expression import Tensor as Tensor_ -from .._c_expression.typing import Float from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \ _expand from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ - _raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \ - _type_convert, _canonicalize_axis, _list_comprehensions, _ceil + _raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \ + _canonicalize_axis, _list_comprehensions, _ceil, _tuple_getitem, _tuple_slice from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to from .dtypes import nan @@ -503,9 +502,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis start, stop = broadcast_arrays(start, stop) axis = _canonicalize_axis(axis, start.ndim+1) bounds_shape = start.shape - bounds_shape = bounds_shape[:axis] + (1,) + bounds_shape[axis:] + bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None) iota_shape = _list_comprehensions(start.ndim+1, 1, True) - iota_shape = iota_shape[:axis] + (num,) + iota_shape[axis+1:] + iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None) num_tensor = _type_convert(Tensor, num).astype(mstype.float32) div = (num_tensor - 1) if endpoint else num_tensor @@ -1542,7 +1541,7 @@ def diag(v, k=0): prod = F.tensor_mul(v, e) cast_type = dtype - if not isinstance(dtype, Float): + if not _check_is_float(dtype): # reduce sum only supports float types cast_type = mstype.float32 prod = F.cast(prod, cast_type) diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index d76d26ff25..90537128c3 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -1809,10 +1809,7 @@ def _check_indices(size, indices, mode): out_of_lowerbounds = F.tensor_lt(indices, lowerbounds) out_of_upperbounds = F.tensor_gt(indices, upperbounds) if mode == 'raise': - # For mode raise, index-out-of-bounds checking is performed at backend since - # evaluation of a boolean scalar Tensor always returns true in graph mode - # regardless of the truth value contained - return indices + _raise_unimplemented_error('"raise" mode is not implemented') if mode == 'wrap': return _mod(indices, F.fill(dtype, shape, size)) zeros = F.fill(dtype, shape, 0) @@ -1821,7 +1818,7 @@ def _check_indices(size, indices, mode): return clipped -def take(a, indices, axis=None, mode='raise'): +def take(a, indices, axis=None, mode='clip'): """ Takes elements from an array along an axis. @@ -1832,6 +1829,7 @@ def take(a, indices, axis=None, mode='raise'): Note: Numpy argument out is not supported. + ``mode = 'raise'`` is not supported, and the default mode is 'clip' instead. Args: a (Tensor): Source array with shape `(Ni…, M, Nk…)`. diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 158e24d1e7..2bd5172df9 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -580,12 +580,12 @@ def minimum(x1, x2, dtype=None): # comparisons with 2 scalars if x1.ndim == 0 and x2.ndim == 0: x1 = expand_dims(x1, 0) - return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype).squeeze() + return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype).squeeze() if x1.ndim == 0: dtype = x2.dtype elif x2.ndim == 0: dtype = x1.dtype - return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype) + return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype) def mean(a, axis=None, keepdims=False, dtype=None): @@ -1299,6 +1299,14 @@ def log(x, dtype=None): return _apply_tensor_op(F.log, x, dtype=dtype) +def _prop_nan(fn, x1, x2): + """Selects NaN if either element is NaN""" + has_nan = F.logical_or(_isnan(x1), _isnan(x2)) + nan_tensor = F.fill(_promote(F.dtype(x1), F.dtype(x2)), F.shape(has_nan), nan) + res = fn(x1, x2) + return F.select(has_nan, nan_tensor, res) + + def maximum(x1, x2, dtype=None): """ Returns the element-wise maximum of array elements. @@ -1349,12 +1357,12 @@ def maximum(x1, x2, dtype=None): # F.maximum does not support when both operands are scalar if x1.ndim == 0 and x2.ndim == 0: x1 = expand_dims(x1, 0) - return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype).squeeze() + return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype).squeeze() if x1.ndim == 0: dtype = x2.dtype elif x2.ndim == 0: dtype = x1.dtype - return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype) + return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype) def heaviside(x1, x2, dtype=None): @@ -1567,7 +1575,7 @@ def hypot(x1, x2, dtype=None): [[5. 5. 5.] [5. 5. 5.] [5. 5. 5.]] - >>> output = np.hypot(3*np.ones((3, 3)), np.array([4])) + >>> output = np.hypot(3*np.ones((3, 3)), np.array([4.0])) >>> print(output) [[5. 5. 5.] [5. 5. 5.] diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py index aa1bbaafd7..06ce30a2f7 100644 --- a/mindspore/numpy/utils_const.py +++ b/mindspore/numpy/utils_const.py @@ -219,6 +219,22 @@ def _raise_runtime_error(info, param=None): raise RuntimeError(info) raise RuntimeError(info + f"{param}") + +def _raise_unimplemented_error(info, param=None): + """ + Raise NotImplementedError in both graph/pynative mode + + Args: + info(str): info string to display + param(python obj): any object that can be recognized by graph mode. If is + not None, then param's value information will be extracted and displayed. + Default is None. + """ + if param is None: + raise NotImplementedError(info) + raise NotImplementedError(info + f"{param}") + + @constexpr def _empty(dtype, shape): """Returns an uninitialized array with dtype and shape.""" @@ -454,3 +470,8 @@ def _seq_prod(seq1, seq2): def _make_tensor(val, dtype): """ Returns the tensor with value `val` and dtype `dtype`.""" return Tensor(val, dtype) + + +def _tuple_slice(tup, start, end): + """get sliced tuple from start and end.""" + return tup[start:end] diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index 7ba20ee2a4..c09dc3d688 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -251,6 +251,16 @@ def test_float_power(): @pytest.mark.env_onecard def test_minimum(): run_binop_test(mnp_minimum, onp_minimum, test_case) + x = onp.random.randint(-10, 10, 20).astype(onp.float32) + y = onp.random.randint(-10, 10, 20).astype(onp.float32) + x[onp.random.randint(0, 10, 3)] = onp.nan + y[onp.random.randint(0, 10, 3)] = onp.nan + x[onp.random.randint(0, 10, 3)] = onp.NINF + y[onp.random.randint(0, 10, 3)] = onp.NINF + x[onp.random.randint(0, 10, 3)] = onp.PINF + y[onp.random.randint(0, 10, 3)] = onp.PINF + match_res(mnp_minimum, onp_minimum, x, y) + match_res(mnp_minimum, onp_minimum, y, x) def mnp_tensordot(x, y): @@ -924,6 +934,16 @@ def onp_maximum(x1, x2): @pytest.mark.env_onecard def test_maximum(): run_binop_test(mnp_maximum, onp_maximum, test_case) + x = onp.random.randint(-10, 10, 20).astype(onp.float32) + y = onp.random.randint(-10, 10, 20).astype(onp.float32) + x[onp.random.randint(0, 10, 3)] = onp.nan + y[onp.random.randint(0, 10, 3)] = onp.nan + x[onp.random.randint(0, 10, 3)] = onp.NINF + y[onp.random.randint(0, 10, 3)] = onp.NINF + x[onp.random.randint(0, 10, 3)] = onp.PINF + y[onp.random.randint(0, 10, 3)] = onp.PINF + match_res(mnp_maximum, onp_maximum, x, y) + match_res(mnp_maximum, onp_maximum, y, x) def mnp_clip(x):