!13310 numpy-native fix linspace, diag, mximum, minimum

From: @jachua
Reviewed-by: @guoqi1024,@liangchenghui
Signed-off-by: @liangchenghui
pull/13310/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7c4d7a89e1

@ -22,15 +22,14 @@ from ..ops.primitive import constexpr
from ..nn.layer.basic import tril as nn_tril
from ..nn.layer.basic import triu as nn_triu
from .._c_expression import Tensor as Tensor_
from .._c_expression.typing import Float
from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
_expand
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
_raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \
_type_convert, _canonicalize_axis, _list_comprehensions, _ceil
_raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_getitem, _tuple_slice
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
from .dtypes import nan
@ -503,9 +502,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
start, stop = broadcast_arrays(start, stop)
axis = _canonicalize_axis(axis, start.ndim+1)
bounds_shape = start.shape
bounds_shape = bounds_shape[:axis] + (1,) + bounds_shape[axis:]
bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
iota_shape = _list_comprehensions(start.ndim+1, 1, True)
iota_shape = iota_shape[:axis] + (num,) + iota_shape[axis+1:]
iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
div = (num_tensor - 1) if endpoint else num_tensor
@ -1542,7 +1541,7 @@ def diag(v, k=0):
prod = F.tensor_mul(v, e)
cast_type = dtype
if not isinstance(dtype, Float):
if not _check_is_float(dtype):
# reduce sum only supports float types
cast_type = mstype.float32
prod = F.cast(prod, cast_type)

@ -1809,10 +1809,7 @@ def _check_indices(size, indices, mode):
out_of_lowerbounds = F.tensor_lt(indices, lowerbounds)
out_of_upperbounds = F.tensor_gt(indices, upperbounds)
if mode == 'raise':
# For mode raise, index-out-of-bounds checking is performed at backend since
# evaluation of a boolean scalar Tensor always returns true in graph mode
# regardless of the truth value contained
return indices
_raise_unimplemented_error('"raise" mode is not implemented')
if mode == 'wrap':
return _mod(indices, F.fill(dtype, shape, size))
zeros = F.fill(dtype, shape, 0)
@ -1821,7 +1818,7 @@ def _check_indices(size, indices, mode):
return clipped
def take(a, indices, axis=None, mode='raise'):
def take(a, indices, axis=None, mode='clip'):
"""
Takes elements from an array along an axis.
@ -1832,6 +1829,7 @@ def take(a, indices, axis=None, mode='raise'):
Note:
Numpy argument out is not supported.
``mode = 'raise'`` is not supported, and the default mode is 'clip' instead.
Args:
a (Tensor): Source array with shape `(Ni, M, Nk)`.

@ -580,12 +580,12 @@ def minimum(x1, x2, dtype=None):
# comparisons with 2 scalars
if x1.ndim == 0 and x2.ndim == 0:
x1 = expand_dims(x1, 0)
return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype).squeeze()
return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype).squeeze()
if x1.ndim == 0:
dtype = x2.dtype
elif x2.ndim == 0:
dtype = x1.dtype
return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype)
return _apply_tensor_op(functools.partial(_prop_nan, F.minimum), x1, x2, dtype=dtype)
def mean(a, axis=None, keepdims=False, dtype=None):
@ -1299,6 +1299,14 @@ def log(x, dtype=None):
return _apply_tensor_op(F.log, x, dtype=dtype)
def _prop_nan(fn, x1, x2):
"""Selects NaN if either element is NaN"""
has_nan = F.logical_or(_isnan(x1), _isnan(x2))
nan_tensor = F.fill(_promote(F.dtype(x1), F.dtype(x2)), F.shape(has_nan), nan)
res = fn(x1, x2)
return F.select(has_nan, nan_tensor, res)
def maximum(x1, x2, dtype=None):
"""
Returns the element-wise maximum of array elements.
@ -1349,12 +1357,12 @@ def maximum(x1, x2, dtype=None):
# F.maximum does not support when both operands are scalar
if x1.ndim == 0 and x2.ndim == 0:
x1 = expand_dims(x1, 0)
return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype).squeeze()
return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype).squeeze()
if x1.ndim == 0:
dtype = x2.dtype
elif x2.ndim == 0:
dtype = x1.dtype
return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype)
return _apply_tensor_op(functools.partial(_prop_nan, F.maximum), x1, x2, dtype=dtype)
def heaviside(x1, x2, dtype=None):
@ -1567,7 +1575,7 @@ def hypot(x1, x2, dtype=None):
[[5. 5. 5.]
[5. 5. 5.]
[5. 5. 5.]]
>>> output = np.hypot(3*np.ones((3, 3)), np.array([4]))
>>> output = np.hypot(3*np.ones((3, 3)), np.array([4.0]))
>>> print(output)
[[5. 5. 5.]
[5. 5. 5.]

@ -219,6 +219,22 @@ def _raise_runtime_error(info, param=None):
raise RuntimeError(info)
raise RuntimeError(info + f"{param}")
def _raise_unimplemented_error(info, param=None):
"""
Raise NotImplementedError in both graph/pynative mode
Args:
info(str): info string to display
param(python obj): any object that can be recognized by graph mode. If is
not None, then param's value information will be extracted and displayed.
Default is None.
"""
if param is None:
raise NotImplementedError(info)
raise NotImplementedError(info + f"{param}")
@constexpr
def _empty(dtype, shape):
"""Returns an uninitialized array with dtype and shape."""
@ -454,3 +470,8 @@ def _seq_prod(seq1, seq2):
def _make_tensor(val, dtype):
""" Returns the tensor with value `val` and dtype `dtype`."""
return Tensor(val, dtype)
def _tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]

@ -251,6 +251,16 @@ def test_float_power():
@pytest.mark.env_onecard
def test_minimum():
run_binop_test(mnp_minimum, onp_minimum, test_case)
x = onp.random.randint(-10, 10, 20).astype(onp.float32)
y = onp.random.randint(-10, 10, 20).astype(onp.float32)
x[onp.random.randint(0, 10, 3)] = onp.nan
y[onp.random.randint(0, 10, 3)] = onp.nan
x[onp.random.randint(0, 10, 3)] = onp.NINF
y[onp.random.randint(0, 10, 3)] = onp.NINF
x[onp.random.randint(0, 10, 3)] = onp.PINF
y[onp.random.randint(0, 10, 3)] = onp.PINF
match_res(mnp_minimum, onp_minimum, x, y)
match_res(mnp_minimum, onp_minimum, y, x)
def mnp_tensordot(x, y):
@ -924,6 +934,16 @@ def onp_maximum(x1, x2):
@pytest.mark.env_onecard
def test_maximum():
run_binop_test(mnp_maximum, onp_maximum, test_case)
x = onp.random.randint(-10, 10, 20).astype(onp.float32)
y = onp.random.randint(-10, 10, 20).astype(onp.float32)
x[onp.random.randint(0, 10, 3)] = onp.nan
y[onp.random.randint(0, 10, 3)] = onp.nan
x[onp.random.randint(0, 10, 3)] = onp.NINF
y[onp.random.randint(0, 10, 3)] = onp.NINF
x[onp.random.randint(0, 10, 3)] = onp.PINF
y[onp.random.randint(0, 10, 3)] = onp.PINF
match_res(mnp_maximum, onp_maximum, x, y)
match_res(mnp_maximum, onp_maximum, y, x)
def mnp_clip(x):

Loading…
Cancel
Save