!10097 Fix bugs for np.mean, np.asarray, np.concatenate, np.linspace

From: @yanglf1121
Reviewed-by: 
Signed-off-by:
pull/10097/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 893d59afd7

@ -25,11 +25,14 @@ from ..ops.primitive import constexpr
from .utils import _check_shape, _check_shape_compile, _check_dtype, _check_is_int, \
_check_axes_range, _check_start_normalize, _check_shape_contain_zero, _check_is_tensor, \
_check_input_for_asarray
_check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, _check_is_list, \
_covert_list_tensor_to_tuple_tensor
DEFAULT_FLOAT_DTYPE = mstype.float32
DEFAULT_INT_DTYPE = mstype.int32
# According to official numpy reference, the dimension of a numpy array must be less
# than 32
MAX_NUMPY_DIMS = 32
def array(obj, dtype=None, copy=True, ndmin=0):
"""
@ -115,6 +118,10 @@ def asarray(a, dtype=None):
dtype = mstype.bool_
if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
# If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32
@ -175,6 +182,10 @@ def asfarray(a, dtype=DEFAULT_FLOAT_DTYPE):
dtype = DEFAULT_FLOAT_DTYPE
if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = onp.asarray(a)
if isinstance(a, onp.ndarray):
@ -317,8 +328,10 @@ def arange(*args, **kwargs):
implementation.
Args:
start(Union[int, float], optional): Start of interval. The interval includes
this value. Default is 0.
start(Union[int, float]): Start of interval. The interval includes this value.
When stop is provided as a position argument, start must be given, when stop
is a normal argument, start can be optional, and default is 0.
Please see additional examples below.
stop(Union[int, float], optional): End of interval. The interval does not
include this value, except in some cases where step is not an integer
and floating point round-off affects the length of out.
@ -340,6 +353,13 @@ def arange(*args, **kwargs):
>>> import mindspore.numpy as np
>>> print(np.arange(0, 5, 1))
[0 1 2 3 4]
>>> print(np.arange(3))
[0 1 2]
>>> print(np.arange(start=0, stop=3))
[0 1 2]
>>> print(np.arange(0, stop=3, step=0.5))
[0. 0.5 1. 1.5 2. 2.5]
>>> print(np.arange(stop=3)) # This will lead to TypeError
"""
# infer the dtype, if either of start, end, step is float, default dtype is
# float32, else int32.
@ -419,6 +439,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
if isinstance(stop, Tensor):
stop = stop.asnumpy()
if not isinstance(num, int):
raise TypeError(f"num should be an integer, but got {type(num)}")
final_dtype = None
if dtype is not None:
final_dtype = _check_dtype(dtype)
@ -990,7 +1013,7 @@ def concatenate(arrays, axis=0):
# if only one tensor is provided, it is treated as a tuple along the
# first dimension. For example, a tensor of shape (3,4,5) will be treated
# as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5))
if axis is None:
if axis is None or axis >= MAX_NUMPY_DIMS:
return ravel(arrays)
arr_shape = F.shape(arrays)
_check_axes_range((axis,), len(arr_shape))
@ -1002,11 +1025,16 @@ def concatenate(arrays, axis=0):
return arrays
flattened_arrays = ()
if axis is None:
if axis is None or axis >= MAX_NUMPY_DIMS:
for arr in arrays:
flattened_arrays += (ravel(arr),)
axis = -1
return P.Concat(axis)(flattened_arrays)
# convert a list of tensor to a tuple of tensor
if _check_is_list(array_type):
arrays = _covert_list_tensor_to_tuple_tensor(arrays)
arr_shape = F.shape(arrays[0])
_check_axes_range((axis,), len(arr_shape))

@ -15,7 +15,8 @@
"""math operations, the function docs are adapted from Numpy API."""
from ..ops import operations as P
from ..ops import functional as F
from .array_ops import squeeze
from ..ops.primitive import constexpr
from .array_ops import squeeze, asarray
from .utils import _infer_out_shape, _is_scalar, _check_axis_valid, _get_device_compile, \
_check_shape_aligned
@ -63,6 +64,8 @@ def mean(a, axis=None, keepdims=False):
2.5
"""
axis = _check_axis_valid(axis, P.Rank()(a))
if _is_empty(F.shape(a)):
return _nan()
if _is_scalar(a.shape):
if keepdims:
return a
@ -140,6 +143,17 @@ def inner(a, b):
return res
@constexpr
def _nan():
"""Returns a Tensor with nan value"""
return asarray(float('nan'))
def _is_empty(shape):
"""Checks if the shape is empty"""
return F.shape_mul(shape) == 0
def _expand(x, ndim):
"""Expand x to ndim"""
while P.Rank()(x) < ndim:

@ -135,6 +135,37 @@ def _check_dtype(dtype):
return dtype
def _deep_list(array_like):
"""convert nested tuple/list mixtures to pure nested list"""
if isinstance(array_like, (list, tuple)):
return list(map(_deep_list, array_like))
return array_like
def _deep_tensor_to_nparray(array_like):
"""
convert a nested list of tensor to nested list of np_array.
Args:
array_like(list(tensor)): In any format of nested lists that may contain
tensors.
Returns:
array_like(list(np_array)): Formatted array that can be directly processed
by numpy.array(), with all tensor elements converted to numpy_array.
"""
# Recursively check whether each element is a tensor or not, if is tensor,
# convert it to a numpy array in place
if isinstance(array_like, Tensor):
return array_like.asnumpy()
if isinstance(array_like, list):
for idx, value in enumerate(array_like):
array_like[idx] = _deep_tensor_to_nparray(value)
return array_like
def _check_input_for_asarray(array_like):
"""check whether array_like argument is a valid type for np.asarray conversion"""
if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
@ -166,6 +197,14 @@ def _get_device():
return context.get_context('device_target')
def _covert_list_tensor_to_tuple_tensor(list_of_tensor):
"""Convert a list of tensor to a tuple of tensor"""
tuple_of_tensor = ()
for tensor in list_of_tensor:
tuple_of_tensor += (tensor,)
return tuple_of_tensor
def _get_mode():
"""Get the current mode (0 is Graph mode, 1 is PyNative mode)"""
return context.get_context('mode')

@ -85,6 +85,14 @@ def test_asarray():
expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy()
match_array(actual, expected, error=7)
# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)
def test_array():
# array's function is very similar to asarray, so we mainly test the
@ -100,6 +108,14 @@ def test_array():
assert arr1 is not arr3
assert arr4 is arr5
# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)
def test_asfarray():
test_case = Cases()
@ -120,6 +136,14 @@ def test_asfarray():
expected = mnp.asfarray(array, test_case.mnp_dtypes[i]).asnumpy()
match_array(actual, expected, error=7)
# Additional tests for nested tensor/numpy_array mixture
mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]]
actual = onp.asarray(onp_input)
expected = mnp.asarray(mnp_input).asnumpy()
match_array(actual, expected, error=7)
def test_zeros():
test_case = Cases()
@ -547,20 +571,9 @@ def test_exec():
return test_exec_case
raise_set = [
('Expand_dims_Error', {
'block': (lambda x: mnp.expand_dims, {'exception': ValueError}),
'desc_inputs': [mnp.ones((2, 3, 4)), 0]}),
]
def expand_dims_exception(input_tensor):
return mnp.expand_dims(input_tensor, 1.2)
def test_expand_dims_exception():
with pytest.raises(TypeError):
expand_dims_exception(mnp.ones((3, 3)))
mnp.expand_dims(mnp.ones((3, 3)), 1.2)
def test_asarray_exception():
@ -568,10 +581,11 @@ def test_asarray_exception():
mnp.asarray({1, 2, 3})
def swapaxes_exception(input_tensor):
return mnp.swapaxes(input_tensor, 1, 10)
def test_swapaxes_exception():
with pytest.raises(TypeError):
mnp.swapaxes(mnp.ones((3, 3)), 1, 10)
def test_swapaxes_exception():
def test_linspace_exception():
with pytest.raises(TypeError):
swapaxes_exception(mnp.ones((3, 3)))
mnp.linspace(0, 1, num=2.5)

Loading…
Cancel
Save