diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py index f14dca72d8..9af22826b7 100644 --- a/mindspore/numpy/__init__.py +++ b/mindspore/numpy/__init__.py @@ -30,13 +30,14 @@ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, res ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d, column_stack, hstack, dstack, vstack, stack, unique, moveaxis, tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit, - flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat) + flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat, + rot90, select, array_split) from .array_creations import copy_ as copy from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, linspace, logspace, eye, identity, empty, empty_like, ones_like, zeros_like, full_like, diagonal, tril, triu, tri, trace, meshgrid, mgrid, ogrid, diagflat, - diag, diag_indices, ix_) + diag, diag_indices, ix_, indices, geomspace, vander) from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, numeric_types, PINF, NINF) @@ -45,35 +46,51 @@ from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, - exp, expm1, cumsum) + exp, expm1, exp2, kron, promote_types, divmod_, diff, cbrt, + cross, ceil, trapz, gcd, lcm, convolve, log1p, logaddexp, log2, + logaddexp2, log10, ediff1d, nansum, nanmean, nanvar, nanstd, cumsum, nancumsum, + sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, arcsinh, arccosh, + arctanh, arctan2, cov) from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, - isnan, isinf, isposinf, isneginf, isscalar) + isnan, isinf, isposinf, isneginf, isscalar, logical_and, logical_not, + logical_or, logical_xor, in1d, isin, isclose) mod = remainder fabs = absolute +divmod = divmod_ # pylint: disable=redefined-builtin +abs = absolute # pylint: disable=redefined-builtin +max = amax # pylint: disable=redefined-builtin +min = amin # pylint: disable=redefined-builtin + array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape', 'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d', 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis', 'tile', 'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit', 'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take', - 'repeat'] + 'repeat', 'rot90', 'select', 'array_split'] array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', - 'diag_indices', 'ix_', 'cumsum'] + 'diag_indices', 'ix_', 'indices', 'geomspace', 'vander'] math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power', 'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal', 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 
'floor_divide', - 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'cumsum'] + 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs', 'exp2', 'kron', + 'promote_types', 'divmod', 'diff', 'cbrt', 'cross', 'ceil', 'trapz', + 'abs', 'max', 'min', 'gcd', 'lcm', 'log1p', 'logaddexp', 'log2', 'logaddexp2', 'log10', + 'convolve', 'ediff1d', 'nansum', 'nanmean', 'nanvar', 'nanstd', 'cumsum', + 'nancumsum', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan', 'sinh', 'cosh', 'tanh', + 'arcsinh', 'arccosh', 'arctanh', 'arctan2', 'cov'] logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', - 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] + 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar', 'logical_and', 'logical_not', + 'logical_or', 'logical_xor', 'in1d', 'isin', 'isclose'] __all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index c7837685d8..34ac0521c2 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================ """array operations, the function docs are adapted from Numpy API.""" -from copy import deepcopy - import numpy as onp from ..common import Tensor @@ -27,10 +25,11 @@ from .._c_expression import Tensor as Tensor_ from .._c_expression.typing import Float from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ - _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar + _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \ + _expand from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ - _raise_type_error, _expanded_shape, _check_is_float, _iota, \ + _raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \ _type_convert, _canonicalize_axis, _list_comprehensions, _ceil from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to from .dtypes import nan @@ -49,9 +48,8 @@ def array(obj, dtype=None, copy=True, ndmin=0): This function creates tensors from an array-like object. Args: - obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a `Tensor`. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + obj (Union[int, float, bool, list, tuple]): Input data, in any form that + can be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type of the new tensor will be inferred from obj. Default is :class:`None`. @@ -76,17 +74,53 @@ def array(obj, dtype=None, copy=True, ndmin=0): >>> print(np.array([1,2,3])) [1 2 3] """ - if ndmin > 0: - # Fall back to original numpy creation. 
- if isinstance(obj, Tensor): - obj = obj.asnumpy() - return asarray(onp.array(obj, dtype, copy=copy, ndmin=ndmin)) + res = asarray(obj, dtype) + if ndmin > res.ndim: + res = _expand(res, ndmin) + + if copy: + res = copy_(res) + elif dtype is not None and dtype != res.dtype: + res = res.astype(dtype) + + return res + + +@constexpr +def asarray_const(a, dtype=None): + """Converts the input to tensor. Note here `a` cannot be tensor itself.""" + _check_input_for_asarray(a) + + if dtype is not None: + dtype = _check_dtype(dtype) + + if isinstance(a, (float, int, bool)) and dtype is None: + dtype = _get_dtype_from_scalar(a) + + if isinstance(a, (list, tuple)): + # Convert all tuple/nested tuples to lists + a = _deep_list(a) + # Convert all tensor sub-elements to numpy arrays + a = _deep_tensor_to_nparray(a) + a = onp.asarray(a) + if a.dtype is onp.dtype('object'): + raise ValueError('Input array must have the same size across all dimensions.') + # If dtype is not specified, we keep consistent with numpy decision + # only exceptions are: we use int/float32 + if dtype is None: + dtype = mstype.pytype_to_dtype(a.dtype) + if dtype == mstype.float64: + dtype = mstype.float32 + elif dtype == mstype.int64: + dtype = mstype.int32 - if not copy: - return asarray(obj, dtype=dtype) + if isinstance(a, onp.ndarray) and dtype is None: + if a.dtype is onp.dtype('object'): + raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") + dtype = mstype.pytype_to_dtype(a.dtype) + a = Tensor.from_numpy(a) - obj = deepcopy(obj) - return asarray(obj, dtype=dtype) + return Tensor(a, dtype=dtype) def asarray(a, dtype=None): @@ -96,9 +130,8 @@ def asarray(a, dtype=None): This function converts tensors from an array-like object. Args: - a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a `Tensor`. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can + be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type of the new tensor will be inferred from obj. Default is :class:`None`. @@ -118,46 +151,28 @@ def asarray(a, dtype=None): >>> print(np.asarray([1,2,3])) [1 2 3] """ - _check_input_for_asarray(a) - - if dtype is not None: - dtype = _check_dtype(dtype) + if isinstance(a, Tensor): + if dtype is None or dtype == a.dtype: + return a + return a.astype(dtype) + return asarray_const(a, dtype) - if isinstance(a, (float, int, bool)) and dtype is None: - dtype = _get_dtype_from_scalar(a) +@constexpr +def asfarray_const(a, dtype=mstype.float32): + """Converts the input to tensor. 
Note here `a` cannot be tensor itself.""" + _check_input_for_asarray(a) if isinstance(a, (list, tuple)): # Convert all tuple/nested tuples to lists a = _deep_list(a) # Convert all tensor sub-elements to numpy arrays a = _deep_tensor_to_nparray(a) a = onp.asarray(a) - if a.dtype is onp.dtype('object'): - raise ValueError('Input array must have the same size across all dimensions.') - # If dtype is not specified, we keep consistent with numpy decision - # only exceptions are: we use int/float32 - if dtype is None: - dtype = mstype.pytype_to_dtype(a.dtype) - if dtype == mstype.float64: - dtype = mstype.float32 - elif dtype == mstype.int64: - dtype = mstype.int32 - - if isinstance(a, onp.ndarray) and dtype is None: if a.dtype is onp.dtype('object'): raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") - dtype = mstype.pytype_to_dtype(a.dtype) a = Tensor.from_numpy(a) - # If a is already a tensor and we don't need to cast dtype, return a - if isinstance(a, Tensor): - if dtype is None or dtype == a.dtype: - return a - - return Tensor(a, dtype=dtype) - - -asarray_const = constexpr(asarray) + return Tensor(a, dtype) def asfarray(a, dtype=mstype.float32): @@ -167,9 +182,8 @@ def asfarray(a, dtype=mstype.float32): If non-float dtype is defined, this function will return a float32 tensor instead. Args: - a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a `Tensor`. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can + be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`. @@ -190,27 +204,18 @@ def asfarray(a, dtype=mstype.float32): >>> print(np.asfarray([1,2,3])) [1. 2. 3.] """ - _check_input_for_asarray(a) - if dtype is None: return asarray(a) dtype = _check_dtype(dtype) - if dtype not in (mstype.float16, mstype.float32, mstype.float64): + # pylint: disable=consider-using-in + if dtype != mstype.float16 and dtype != mstype.float32 and dtype != mstype.float64: dtype = mstype.float32 - if isinstance(a, (list, tuple)): - # Convert all tuple/nested tuples to lists - a = _deep_list(a) - # Convert all tensor sub-elements to numpy arrays - a = _deep_tensor_to_nparray(a) - a = onp.asarray(a) - if a.dtype is onp.dtype('object'): - raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") - if isinstance(a, onp.ndarray): - a = Tensor.from_numpy(a) + if isinstance(a, Tensor): + return a.astype(dtype) - return Tensor(a, dtype) + return asfarray_const(a) def copy_(a): @@ -218,9 +223,8 @@ def copy_(a): Returns a tensor copy of the given object. Args: - a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + a (Union[int, float, bool, list, tuple, Tensor]): Input data, in any form that can + be converted to a `Tensor`. This includes Tensor, list, tuple and numbers. Returns: Tensor, has the same data as `a`. 
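The reworked `array`/`asarray`/`asfarray` constructors above no longer round-trip through host-side NumPy for tensor inputs, and they narrow the 64-bit NumPy defaults to 32-bit. A minimal usage sketch of that behaviour, based on the docstrings and dtype rules in this hunk (printed results are indicative only and may vary by backend):

>>> import mindspore.numpy as np
>>> a = np.asarray([1, 2, 3])           # integer input is stored as int32, not int64
>>> b = np.asarray([1.0, 2.0, 3.0])     # float input is stored as float32, not float64
>>> c = np.array(b, ndmin=3)            # ndmin pads leading axes, giving shape (1, 1, 3)
>>> d = np.asfarray([1, 2, 3], dtype=np.int32)  # non-float dtype falls back to float32
>>> print(a.dtype, b.dtype, c.shape, d.dtype)
Int32 Float32 (1, 1, 3) Float32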
@@ -241,8 +245,16 @@ def copy_(a): """ if not isinstance(a, Tensor): a = asarray_const(a) - return a.copy() - + # The current implementation registers a new memory location for copied tensor by + # doing some reduandent operations. + origin_dtype = a.dtype + if origin_dtype == mstype.bool_: + return F.logical_not(F.logical_not(a)) + if origin_dtype != mstype.float64: + a = a.astype("float32") + a = a / ones_like(a) + a = a.astype(origin_dtype) + return a def ones(shape, dtype=mstype.float32): """ @@ -566,6 +578,65 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): return F.tensor_pow(base, linspace_res).astype(dtype) +def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): + """ + Returns numbers spaced evenly on a log scale (a geometric progression). + + This is similar to logspace, but with endpoints specified directly. Each output sample + is a constant multiple of the previous. + + Args: + start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence. + stop (Union[int, list(int), tuple(int), tensor]): The final value of the sequence, + unless endpoint is False. In that case, num + 1 values are spaced over the + interval in log-space, of which all but the last (a sequence of length num) are + returned. + num (int, optional): Number of samples to generate. Default is 50. + endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is + not included. Default is True. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can + be in format of np.float32, or `float32`.If `dtype` is None, infer the data + type from other input arguments. Default is None. + axis (int, optional): The axis in the result to store the samples. Relevant + only if start or stop is array-like. By default (0), the samples will + be along a new axis inserted at the beginning. Use -1 to get an axis at the end. + Default is 0. + + Returns: + Tensor, with samples equally spaced on a log scale. + + Raises: + TypeError: If input arguments have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.geomspace(1, 256, num=9) + >>> print(output) + [ 1. 2. 4. 8. 16. 32. 64. 128. 256.] + >>> output = np.geomspace(1, 256, num=8, endpoint=False) + >>> print(output) + [ 1. 2. 4. 8. 16. 32. 64. 128.] + """ + start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis) + root = num + if endpoint: + root -= 1 + bases = F.tensor_pow(F.tensor_div(stop, start), asarray_const(1/(root))) + exponents = linspace(zeros(F.shape(bases)), F.fill(F.dtype(bases), F.shape(bases), root), + num, endpoint=endpoint, dtype=dtype, axis=axis) + shape = F.shape(bases) + axis = axis + F.rank(bases) + 1 if axis < 0 else axis + expanded_shape = _tuple_getitem(shape, axis, False) + (1,) + _tuple_getitem(shape, axis) + bases = F.reshape(bases, expanded_shape) + start = F.reshape(start, expanded_shape) + res = F.tensor_mul(F.tensor_pow(bases, exponents), start) + if dtype is not None: + res = F.cast(res, dtype) + return res + + def eye(N, M=None, k=0, dtype=mstype.float32): """ Returns a 2-D tensor with ones on the diagnoal and zeros elsewhere. 
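For reference, the `geomspace` implementation in the hunk above computes each sample as ``start * bases**k`` with ``bases = (stop / start) ** (1 / root)``. A plain-NumPy sketch of the same arithmetic (`geomspace_sketch` is a hypothetical helper used only for illustration, not part of the change):

>>> import numpy as onp
>>> def geomspace_sketch(start, stop, num=50, endpoint=True):
...     root = num - 1 if endpoint else num          # number of ratio steps
...     ratio = (stop / start) ** (1.0 / root)       # the 'bases' value in the code above
...     exponents = onp.linspace(0.0, root, num, endpoint=endpoint)
...     return start * ratio ** exponents
>>> print(geomspace_sketch(1, 256, num=9))
[  1.   2.   4.   8.  16.  32.  64. 128. 256.]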
@@ -757,7 +828,7 @@ def empty_like(prototype, dtype=None, shape=None): Examples: >>> import mindspore.numpy as np - >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> a = np.ones((4,1,2)) >>> output = np.empty_like(a) >>> print(output) # result may vary @@ -794,7 +865,7 @@ def ones_like(a, dtype=None, shape=None): Examples: >>> import mindspore.numpy as np - >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> a = np.ones((4,1,2)) >>> output = np.ones_like(a) >>> print(output) [[[1. 1.]] @@ -832,7 +903,7 @@ def zeros_like(a, dtype=None, shape=None): Examples: >>> import mindspore.numpy as np - >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> a = np.ones((4,1,2)) >>> output = np.zeros_like(a) >>> print(output) [[[0. 0.]] @@ -871,7 +942,7 @@ def full_like(a, fill_value, dtype=None, shape=None): Examples: >>> import mindspore.numpy as np - >>> a = [[(1, 2)], np.ones((1, 2)), [[2, 3]], np.ones((1, 2))] + >>> a = np.ones((4,1,2)) >>> output = np.full_like(a, 0.5) >>> print(output) [[[0.5 0.5]] @@ -1175,9 +1246,8 @@ def _index(i, size, Cartesian=True): if Cartesian: if i == 1: return 0 - if i == 0: - if size >= 2: - return 1 + if i == 0 and size >= 2: + return 1 return i @@ -1630,3 +1700,103 @@ def ix_(*args): return _raise_value_error('Cross index must be 1 dimensional') res += (F.reshape(arr, _expanded_shape(ndim, arr.size, i)),) return res + + +def vander(x, N=None, increasing=False): + """ + Generates a Vandermonde matrix. + + The columns of the output matrix are powers of the input vector. The order of + the powers is determined by the increasing boolean argument. Specifically, when + increasing is `False`, the i-th output column is the input vector raised element-wise + to the power of :math:`N - i - 1`. Such a matrix with a geometric progression in each row + is named for Alexandre-Theophile Vandermonde. + + Args: + x (Union[list, tuple, Tensor]): 1-D input array. + N (int, optional): Number of columns in the output. If N is not specified, a + square array is returned (``N = len(x)``). + increasing (bool, optional): Order of the powers of the columns. If True, the + powers increase from left to right, if False (the default) they are reversed. + + Returns: + Vandermonde matrix. If `increasing` is `False`, the first column is :math:`x^{(N-1)}`, + the second :math:`x^{(N-2)}` and so forth. If `increasing` is `True`, the columns are + :math:`x^0, x^1, ..., x^{(N-1)}`. + + Raises: + TypeError: If inputs have types not specified above. + ValueError: If `x` is not 1-D, or `N` < 0. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> print(np.vander([1,2,3,4,5])) + [[ 1 1 1 1 1] + [ 16 8 4 2 1] + [ 81 27 9 3 1] + [256 64 16 4 1] + [625 125 25 5 1]] + """ + if isinstance(x, (list, tuple)): + x = asarray_const(x) + elif not isinstance(x, Tensor): + _raise_type_error("Input x must be list, tuple or Tensor, but got ", x) + if x.ndim != 1: + _raise_value_error("Input x must be 1-D, but got dimension=", x.ndim) + N = N or x.size + if not isinstance(N, int): + _raise_type_error("Input N must be an integer.") + if N <= 0: + _raise_value_error("Input N must > 0.") + if not isinstance(increasing, bool): + _raise_type_error("increasing must be a bool.") + exponent = _iota(x.dtype, N, increasing) + x = F.expand_dims(x, 1) + exponent = F.expand_dims(exponent, 0) + return F.tensor_pow(x, exponent) + + +def indices(dimensions, dtype=mstype.int32, sparse=False): + """ + Returns an array representing the indices of a grid. + + Computes an array where the subarrays contain index values 0, 1, … + varying only along the corresponding axis. + + Args: + dimensions (tuple or list of ints): The shape of the grid. + dtype (data type, optional): Data type of the result. + sparse (boolean, optional): Defaults to False. Return a sparse + representation of the grid instead of a dense representation. + + Returns: + Tensor or tuple of Tensor, If `sparse` is False, returns one array + of grid indices, ``grid.shape = (len(dimensions),) + tuple(dimensions)``. + If sparse is True, returns a tuple of arrays, with + ``grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)`` with + ``dimensions[i]`` in the `ith` place + + Raises: + TypeError: if input dimensions is not a tuple or list. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> grid = np.indices((2, 3)) + >>> print(indices) + [Tensor(shape=[2, 3], dtype=Int32, value= + [[0, 0, 0], + [1, 1, 1]]), Tensor(shape=[2, 3], dtype=Int32, value= + [[0, 1, 2], + [0, 1, 2]])] + """ + if not isinstance(dimensions, (tuple, list)): + _raise_type_error('Shape of the grid must be tuple or list') + grids = () + for d in dimensions: + grids += (arange(d, dtype=dtype),) + return meshgrid(*grids, sparse=sparse, indexing='ij') diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index 40ef8f6f7d..d76d26ff25 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -24,62 +24,19 @@ from ..ops.primitive import constexpr from ..nn import Cell from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \ - _check_input_tensor, _broadcast_to + _check_input_tensor, _broadcast_to, _to_tensor from .utils_const import _check_axes_range, _check_start_normalize, \ _raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \ _check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \ _check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \ _list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \ - _tuple_getitem, _expanded_shape + _tuple_getitem, _expanded_shape, _seq_prod, _get_device, _tuple_setitem # According to official numpy reference, the dimension of a numpy array must be less # than 32 MAX_NUMPY_DIMS = 32 -@constexpr -def _prepare_shape_for_expand_dims(shape, axes): - """ - Creates the expanded new shape based on the shape and given axes - - Args: - shape (tuple): the shape of the tensor - axes Union(int, tuple(int), list(int)): the axes with dimensions 
expanded. - - Returns: - new_shape(tuple): the shape with dimensions expanded. - """ - - new_shape = [] - shape_idx = 0 - new_shape_length = len(shape) - - # Convert to set - if isinstance(axes, int): - new_shape_length += 1 - if axes >= new_shape_length or axes < -new_shape_length: - raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {new_shape_length}") - axes = {axes} - - elif isinstance(axes, (list, tuple)): - new_shape_length += len(axes) - for axis in axes: - if axis >= new_shape_length or axis < -new_shape_length: - raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {new_shape_length}") - axes = set(axes) - - else: - raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}") - - for new_shape_idx in range(new_shape_length): - if new_shape_idx in axes or new_shape_idx - new_shape_length in axes: - new_shape.append(1) - else: - new_shape.append(shape[shape_idx]) - shape_idx += 1 - return tuple(new_shape) - - def expand_dims(a, axis): """ Expands the shape of a tensor. @@ -109,10 +66,15 @@ def expand_dims(a, axis): (1, 2, 2) """ _check_input_tensor(a) - shape = F.shape(a) - # yield expanded shape based on the axes - new_shape = _prepare_shape_for_expand_dims(shape, axis) - return F.reshape(a, new_shape) + if not isinstance(axis, (int, tuple, list)): + _raise_type_error("axis must be tuple, list or int, but got ", axis) + if isinstance(axis, int): + return F.expand_dims(a, axis) + ndim = a.ndim + len(axis) + axis = _canonicalize_axis(axis, ndim) + for ax in axis: + a = F.expand_dims(a, ax) + return a def squeeze(a, axis=None): @@ -1091,6 +1053,9 @@ def roll(a, shift, axis=None): Returns: Tensor, with the same shape as a. + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + Raises: TypeError: If input arguments have types not specified above. ValueError: If axis exceeds `a.ndim`, or `shift` and `axis` cannot broadcast. @@ -1212,12 +1177,6 @@ def moveaxis(a, source, destination): return F.transpose(a, perm) -@constexpr -def _seq_prod(seq1, seq2): - """Returns the element-wise product of seq1 and seq2.""" - return tuple(map(lambda x, y: x*y, seq1, seq2)) - - def tile(a, reps): """ Constructs an array by repeating `a` the number of times given by `reps`. @@ -1355,6 +1314,60 @@ def broadcast_arrays(*args): return res +def array_split(x, indices_or_sections, axis=0): + """ + Splits a tensor into multiple sub-tensors. + + Note: + Currently, array_split only supports :class:`mindspore.float32` on ``CPU``. + + The only difference between ``np.split`` and ``np.array_split`` is that + ``np.array_split`` allows indices_or_sections to be an integer that does not + equally divide the axis. For a tensor of length l that should be split into + n sections, it returns :math:`l % n` sub-arrays of size :math:`l//n + 1` and + the rest of size :math:`l//n`. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If integer, :math:`N`, the tensor will be divided into + :math:`N` tensors along axis. + If tuple(int), list(int) or of sorted integers, + the entries indicate where along axis the array is split. + For example, :math:`[2, 3]` would, for :math:`axis=0`, result in + three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. + If an index exceeds the dimension of the array along axis, + an empty sub-array is returned correspondingly. + axis (int): The axis along which to split. Default: 0. + + Returns: + A list of sub-tensors. 
+ + Raises: + TypeError: If argument `indices_or_sections` is not integer, + tuple(int) or list(int) or argument `axis` is not integer. + ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.arange(9).astype("float32") + >>> output = np.array_split(input_x, 4) + >>> print(output) + (Tensor(shape=[3], dtype=Float32, + value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]), + Tensor(shape=[2], dtype=Float32, + value= [ 3.00000000e+00, 4.00000000e+00]), + Tensor(shape=[2], dtype=Float32, + value= [ 5.00000000e+00, 6.00000000e+00]), + Tensor(shape=[2], dtype=Float32, + value= [ 7.00000000e+00, 8.00000000e+00])) + """ + return _split(x, indices_or_sections, opname="array_split", axis=axis) + + def split(x, indices_or_sections, axis=0): """ Splits a tensor into multiple sub-tensors along the given axis. @@ -1380,9 +1393,12 @@ def split(x, indices_or_sections, axis=0): tuple(int) or list(int) or argument `axis` is not integer. ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`. + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + Examples: >>> import mindspore.numpy as np - >>> input_x = np.arange(9).astype('float32') + >>> input_x = np.arange(9).astype("float32") >>> output = np.split(input_x, 3) >>> print(output) (Tensor(shape=[3], dtype=Float32, @@ -1392,13 +1408,32 @@ def split(x, indices_or_sections, axis=0): Tensor(shape=[3], dtype=Float32, value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00])) """ + return _split(x, indices_or_sections, opname="split", axis=axis) + + +def _split(x, indices_or_sections, opname, axis=0): + """Splits a tensor based on ``np.split`` or ``np.array_split``.""" _check_input_tensor(x) _ = _check_axis_type(axis, True, False, False) axis = _canonicalize_axis(axis, x.ndim) res = None + arr_shape = x.shape + length_along_dim = arr_shape[axis] if isinstance(indices_or_sections, int): - _split = P.Split(axis, indices_or_sections) - res = _split(x) + if opname == "split" or length_along_dim % indices_or_sections == 0: + res = P.Split(axis, indices_or_sections)(x) + else: + num_long_tensor = length_along_dim % indices_or_sections + num_short_tensor = indices_or_sections - num_long_tensor + length1 = num_long_tensor * (length_along_dim // indices_or_sections + 1) + length2 = length_along_dim - length1 + start1 = _list_comprehensions(F.rank(x), 0, True) + size1 = _tuple_setitem(arr_shape, axis, length1) + start2 = _tuple_setitem(start1, axis, length1) + size2 = _tuple_setitem(arr_shape, axis, length2) + res = P.Split(axis, num_long_tensor)(F.tensor_slice(x, start1, size1)) + \ + P.Split(axis, num_short_tensor)(F.tensor_slice(x, start2, size2)) + elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections): res = _split_sub_tensors(x, indices_or_sections, axis) else: @@ -1921,7 +1956,6 @@ def repeat(a, repeats, axis=None): if repeats == 0: return _empty(F.dtype(a), (0,)) return C.repeat_elements(a, repeats, axis) - shape = F.shape(a) size = shape[axis] if len(repeats) != size: @@ -1932,3 +1966,144 @@ def repeat(a, repeats, axis=None): if rep != 0: repeated_subs.append(C.repeat_elements(sub, rep, axis)) return concatenate(repeated_subs, axis) + + +def rot90(a, k=1, axes=(0, 1)): + """ + Rotates a tensor by 90 degrees in the plane specified by axes. + Rotation direction is from the first towards the second axis. 
+ + Args: + a (Tensor): Input tensor of two or more dimensions. + k (int): Number of times the tensor is rotated by 90 degrees. Default: 1. + axes (Union[tuple(int), list(int)]): The tensor is rotated in the plane + defined by the axes. Default: `(0, 1)`. + Axes must be different and with the shape of `(2,)`. + + Returns: + Tensor. + + Raises: + TypeError: if input `a` is not a Tensor or + the argument `k` is not integer or + the argument `axes` is not tuple of ints or list of ints. + ValueError: if any axis is out of range or + the length of `axes` is not `2`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.arange(24).reshape((2, 3, 4)) + >>> output = np.rot90(a) + >>> print(output) + [[[ 8 9 10 11] + [20 21 22 23]] + [[ 4 5 6 7] + [16 17 18 19]] + [[ 0 1 2 3] + [12 13 14 15]]] + >>> output = np.rot90(a, 3, (1, 2)) + >>> print(output) + [[[ 8 4 0] + [ 9 5 1] + [10 6 2] + [11 7 3]] + [[20 16 12] + [21 17 13] + [22 18 14] + [23 19 15]]] + """ + _check_input_tensor(a) + + if not isinstance(k, int): + _raise_type_error("integer argument expected, but got ", k) + k = k % 4 if k >= 0 else 4 - (-k % 4) + + if not isinstance(axes, (tuple, list)): + _raise_type_error("tuple(ints) or list(ints) expected, but got ", axes) + if len(axes) != 2: + _raise_value_error("len(axes) must be 2.") + axis1, axis2 = axes[0], axes[1] + axis1 = _canonicalize_axis(axis1, a.ndim) + axis2 = _canonicalize_axis(axis2, a.ndim) + if axis1 == axis2: + _raise_value_error('Axes must be different.') + + if k == 0: + return a + if k == 2: + return flip(flip(a, axis1), axis2) + perm = _list_comprehensions(a.ndim) + perm[axis1], perm[axis2] = perm[axis2], perm[axis1] + if k == 1: + return flip(transpose(a, perm), axis1) + return flip(transpose(a, perm), axis2) + + +def select(condlist, choicelist, default=0): + """ + Returns an array drawn from elements in `choicelist`, depending on conditions. + + Args: + condlist (array_like): The list of conditions which determine from which array + in `choicelist` the output elements are taken. When multiple conditions are + satisfied, the first one encountered in `condlist` is used. + choicelist (array_like): The list of arrays from which the output elements are + taken. It has to be of the same length as `condlist`. + default (scalar, optional): The element inserted in output when all conditions + evaluate to `False`. + + Returns: + Tensor, the output at position `m` is the `m-th` element of the array in + `choicelist` where the `m-th` element of the corresponding array in `condlist` + is `True`. + + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Raises: + ValueError: if ``len(condlist) != len(choicelist)``. 
+ + Examples: + >>> condlist = [[True, True, True, False, False], + [False, False, True, False, True]] + >>> choicelist = [[0, 1, 2, 3, 4], [0, 1, 4, 9, 16]] + >>> output = np.select(condlist, choicelist) + >>> print(output) + [ 0 1 2 0 16] + """ + condlist, choicelist = _to_tensor(condlist, choicelist) + shape_cond = F.shape(condlist) + shape_choice = F.shape(choicelist) + if F.rank(condlist) == 0 or F.rank(condlist) == 0: + _raise_value_error('input cannot be scalars') + case_num = shape_cond[0] + if shape_choice[0] != case_num: + _raise_value_error('list of cases must be same length as list of conditions') + + # performs broadcast over the cases in condlist and choicelist + case_size = _infer_out_shape(shape_cond[1:], shape_choice[1:]) + shape_broadcasted = (case_num,) + case_size + ndim = len(shape_broadcasted) + shape_cond_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(condlist), 1, True) + + shape_cond[1:]) + condlist = _broadcast_to_shape(F.reshape(condlist, shape_cond_expanded), shape_broadcasted) + shape_choice_expanded = ((case_num,) + _list_comprehensions(ndim - F.rank(choicelist), 1, True) + + shape_choice[1:]) + choicelist = _broadcast_to_shape(F.reshape(choicelist, shape_choice_expanded), shape_broadcasted) + + slice_start = _list_comprehensions(ndim - 1, 0, True) + slice_size = (1,) + case_size + dtype = F.dtype(choicelist) + if _get_device() == 'CPU' and not _check_is_float(dtype): + # F.tensor_slice only supports float on CPU + choicelist = F.cast(choicelist, mstype.float32) + default_slice = F.fill(F.dtype(choicelist), slice_size, default) + for i in range(case_num - 1, -1, -1): + cond_slice = F.tensor_slice(condlist.astype(mstype.float32), (i,) + slice_start, slice_size) + choice_slice = F.tensor_slice(choicelist, (i,) + slice_start, slice_size) + default_slice = F.select(cond_slice.astype(mstype.bool_), choice_slice, default_slice) + return F.reshape(default_slice, (case_size)).astype(dtype) diff --git a/mindspore/numpy/dtypes.py b/mindspore/numpy/dtypes.py index 1c1a946a1f..a0ad8a668c 100644 --- a/mindspore/numpy/dtypes.py +++ b/mindspore/numpy/dtypes.py @@ -169,3 +169,16 @@ promotion_rule = { (bool_, float32): float32, (bool_, float64): float64, } + +rule_for_trigonometric = {float16: float16, + float32: float32, + float64: float64, + int8: float16, + int16: float32, + int32: float32, + int64: float32, + uint8: float16, + uint16: float32, + uint32: float32, + uint64: float32, + bool_: float16} diff --git a/mindspore/numpy/logic_ops.py b/mindspore/numpy/logic_ops.py index aab930e378..16be5043b0 100644 --- a/mindspore/numpy/logic_ops.py +++ b/mindspore/numpy/logic_ops.py @@ -15,33 +15,29 @@ """logical operations, the function docs are adapted from Numpy API.""" -from .math_ops import _apply_tensor_op from ..ops import functional as F from ..ops.primitive import constexpr from ..common import dtype as mstype from ..common import Tensor from .._c_expression import typing -from .array_creations import zeros, ones -from .utils import _check_input_tensor +from .math_ops import _apply_tensor_op, absolute +from .array_creations import zeros, ones, empty +from .utils import _check_input_tensor, _to_tensor, _isnan +from .utils_const import _raise_type_error, _is_shape_empty, _infer_out_shape -def not_equal(x1, x2, out=None, where=True, dtype=None): +def not_equal(x1, x2, dtype=None): """ Returns (x1 != x2) element-wise. + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. 
+ Args: x1 (Tensor): First input tensor to be compared. x2 (Tensor): Second input tensor to be compared. - out (Tensor or None, optional), default is None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -65,33 +61,21 @@ def not_equal(x1, x2, out=None, where=True, dtype=None): [False True]] """ _check_input_tensor(x1, x2) - return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.not_equal, x1, x2, dtype=dtype) -def less_equal(x1, x2, out=None, where=True, dtype=None): +def less_equal(x1, x2, dtype=None): """ Returns the truth value of ``(x1 <= x2)`` element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Args: x1 (Tensor): Input array. x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -113,33 +97,21 @@ def less_equal(x1, x2, out=None, where=True, dtype=None): [False True True] """ _check_input_tensor(x1, x2) - return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_le, x1, x2, dtype=dtype) -def less(x1, x2, out=None, where=True, dtype=None): +def less(x1, x2, dtype=None): """ Returns the truth value of ``(x1 < x2)`` element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Args: x1 (Tensor): input array. x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). 
- out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -160,33 +132,21 @@ def less(x1, x2, out=None, where=True, dtype=None): >>> print(output) [ True False] """ - return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_lt, x1, x2, dtype=dtype) -def greater_equal(x1, x2, out=None, where=True, dtype=None): +def greater_equal(x1, x2, dtype=None): """ Returns the truth value of ``(x1 >= x2)`` element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Args: x1 (Tensor): Input array. x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -207,33 +167,21 @@ def greater_equal(x1, x2, out=None, where=True, dtype=None): >>> print(output) [ True True False] """ - return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_ge, x1, x2, dtype=dtype) -def greater(x1, x2, out=None, where=True, dtype=None): +def greater(x1, x2, dtype=None): """ Returns the truth value of ``(x1 > x2)`` element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Args: x1 (Tensor): Input array. x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. 
- where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -254,33 +202,21 @@ def greater(x1, x2, out=None, where=True, dtype=None): >>> print(output) [ True False] """ - return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_gt, x1, x2, dtype=dtype) -def equal(x1, x2, out=None, where=True, dtype=None): +def equal(x1, x2, dtype=None): """ Returns the truth value of ``(x1 == x2)`` element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Args: x1 (Tensor): Input array. x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -301,34 +237,22 @@ def equal(x1, x2, out=None, where=True, dtype=None): >>> print(output) [ True True False] """ - return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.equal, x1, x2, dtype=dtype) -def isfinite(x, out=None, where=True, dtype=None): +def isfinite(x, dtype=None): """ Tests element-wise for finiteness (not infinity or not Not a Number). The result is returned as a boolean array. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. On GPU, the supported dtypes are np.float16, and np.float32. Args: x (Tensor): Input values. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. 
- This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -351,37 +275,20 @@ def isfinite(x, out=None, where=True, dtype=None): >>> print(output) False """ - return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype) - - -def _isnan(x): - """Computes isnan without applying keyword arguments.""" - return F.not_equal(x, x) + return _apply_tensor_op(F.isfinite, x, dtype=dtype) -def isnan(x, out=None, where=True, dtype=None): +def isnan(x, dtype=None): """ Tests element-wise for NaN and return result as a boolean array. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Only np.float32 is currently supported. Args: x (Tensor): Input values. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -404,7 +311,7 @@ def isnan(x, out=None, where=True, dtype=None): >>> print(output) False """ - return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_isnan, x, dtype=dtype) def _isinf(x): @@ -419,31 +326,19 @@ def _isinf(x): return F.cast(res, mstype.bool_) -def isinf(x, out=None, where=True, dtype=None): +def isinf(x, dtype=None): """ Tests element-wise for positive or negative infinity. Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are - not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. Only np.float32 is currently supported. Args: x (Tensor): Input values. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. 
- Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -466,7 +361,7 @@ def isinf(x, out=None, where=True, dtype=None): >>> print(output) [ True True False False] """ - return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_isinf, x, dtype=dtype) def _is_sign_inf(x, fn): @@ -562,7 +457,7 @@ def isscalar(element): element (any): Input argument, can be of any type and shape. Returns: - Boolean, True if `element` is a scalar type, False if it is not. + Boolean, True if `element` is a scalar type, False if it is not. Raises: TypeError: if the type of `element` is not supported by mindspore parser. @@ -587,3 +482,302 @@ def isscalar(element): """ obj_type = F.typeof(element) return not isinstance(obj_type, Tensor) and _isscalar(obj_type) + + +def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + """ + Returns a boolean tensor where two tensors are element-wise equal within a tolerance. + + The tolerance values are positive, typically very small numbers. The relative + difference (:math:`rtol * abs(b)`) and the absolute difference `atol` are added together + to compare against the absolute difference between `a` and `b`. + + Note: + For finite values, isclose uses the following equation to test whether two + floating point values are equivalent. + :math:`absolute(a - b) <= (atol + rtol * absolute(b))` + + Args: + a (Union[Tensor, list, tuple]): Input first tensor to compare. + b (Union[Tensor, list, tuple]): Input second tensor to compare. + rtol (Number): The relative tolerance parameter (see Note). + atol (Number): The absolute tolerance parameter (see Note). + equal_nan (bool): Whether to compare ``NaN`` as equal. If True, ``NaN`` in + `a` will be considered equal to ``NaN`` in `b` in the output tensor. + + Returns: + A ``bool`` tensor of where `a` and `b` are equal within the given tolerance. + + Raises: + TypeError: If inputs have types not specified above. 
+ + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([0,1,2,float('inf'),float('inf'),float('nan')]) + >>> b = np.array([0,1,-2,float('-inf'),float('inf'),float('nan')]) + >>> print(np.isclose(a, b)) + [ True True False False True False] + >>> print(np.isclose(a, b, equal_nan=True)) + [ True True False False True True] + """ + a, b = _to_tensor(a, b) + if not isinstance(rtol, (int, float, bool)) or not isinstance(atol, (int, float, bool)): + _raise_type_error("rtol and atol are expected to be numbers.") + if not isinstance(equal_nan, bool): + _raise_type_error("equal_nan is expected to be bool.") + + if _is_shape_empty(a.shape) or _is_shape_empty(b.shape): + return empty(_infer_out_shape(a.shape, b.shape), dtype=mstype.bool_) + rtol = _to_tensor(rtol).astype("float32") + atol = _to_tensor(atol).astype("float32") + res = absolute(a - b) <= (atol + rtol * absolute(b)) + # infs are treated as equal + a_posinf = isposinf(a) + b_posinf = isposinf(b) + a_neginf = isneginf(a) + b_neginf = isneginf(b) + same_inf = F.logical_or(F.logical_and(a_posinf, b_posinf), F.logical_and(a_neginf, b_neginf)) + diff_inf = F.logical_or(F.logical_and(a_posinf, b_neginf), F.logical_and(a_neginf, b_posinf)) + res = F.logical_and(F.logical_or(res, same_inf), F.logical_not(diff_inf)) + both_nan = F.logical_and(_isnan(a), _isnan(b)) + if equal_nan: + res = F.logical_or(both_nan, res) + else: + res = F.logical_and(F.logical_not(both_nan), res) + return res + + +def in1d(ar1, ar2, invert=False): + """ + Tests whether each element of a 1-D array is also present in a second array. + + Returns a boolean array the same length as `ar1` that is True where an element + of `ar1` is in `ar2` and False otherwise. + + Note: + Numpy argument `assume_unique` is not supported since the implementation does + not rely on the uniqueness of the input arrays. + + Args: + ar1 (array_like): Input array with shape `(M,)`. + ar2 (array_like): The values against which to test each value of `ar1`. + invert (boolean, optional): If True, the values in the returned array are + inverted (that is, False where an element of `ar1` is in `ar2` and True + otherwise). Default is False. + + Returns: + Tensor, with shape `(M,)`. The values ``ar1[in1d]`` are in `ar2`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> test = np.array([0, 1, 2, 5, 0]) + >>> states = [0, 2] + >>> mask = np.in1d(test, states) + >>> print(mask) + [ True False True False True] + >>> mask = np.in1d(test, states, invert=True) + >>> print(mask) + [False True False True False] + """ + ar1, ar2 = _to_tensor(ar1, ar2) + ar1 = F.expand_dims(ar1.ravel(), -1) + ar2 = ar2.ravel() + included = F.equal(ar1, ar2) + # F.reduce_sum only supports float + res = F.reduce_sum(included.astype(mstype.float32), -1).astype(mstype.bool_) + if invert: + res = F.equal(res, _to_tensor(False)) + return res + + +def isin(element, test_elements, invert=False): + """ + Calculates element in `test_elements`, broadcasting over `element` only. Returns a + boolean array of the same shape as `element` that is True where an element of + `element` is in `test_elements` and False otherwise. + + Note: + Numpy argument `assume_unique` is not supported since the implementation does + not rely on the uniqueness of the input arrays. + + Args: + element (array_like): Input array. + test_elements (array_like): The values against which to test each value of + `element`. 
+ invert (boolean, optional): If True, the values in the returned array are + inverted, as if calculating `element` not in `test_elements`. Default is False. + + Returns: + Tensor, has the same shape as `element`. The values ``element[isin]`` are in + `test_elements`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> element = 2*np.arange(4).reshape((2, 2)) + >>> test_elements = [1, 2, 4, 8] + >>> mask = np.isin(element, test_elements) + >>> print(mask) + [[False True] + [ True False]] + >>> mask = np.isin(element, test_elements, invert=True) + >>> print(mask) + [[ True False] + [False True]] + """ + res = in1d(element, test_elements, invert=invert) + return F.reshape(res, F.shape(element)) + + +def logical_not(a, dtype=None): + """ + Computes the truth value of NOT `a` element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + a (Tensor): The input tensor whose dtype is bool. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. + Boolean result with the same shape as `a` of the NOT operation on elements of `a`. + This is a scalar if `a` is a scalar. + + Raises: + TypeError: if the input is not a tensor or its dtype is not bool. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.array([True, False]) + >>> output = np.logical_not(a) + >>> print(output) + [False True] + """ + return _apply_tensor_op(F.logical_not, a, dtype=dtype) + + +def logical_or(x1, x2, dtype=None): + """ + Computes the truth value of `x1` OR `x2` element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. + + Args: + x1 (Tensor): Input tensor. + x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([True, False]) + >>> x2 = np.array([False, True]) + >>> output = np.logical_or(x1, x2) + >>> print(output) + [ True True] + """ + return _apply_tensor_op(F.logical_or, x1, x2, dtype=dtype) + + +def logical_and(x1, x2, dtype=None): + """ + Computes the truth value of `x1` AND `x2` element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. + + Args: + x1 (Tensor): Input tensor. + x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. + Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; + the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([True, False]) + >>> x2 = np.array([False, False]) + >>> output = np.logical_and(x1, x2) + >>> print(output) + [False False] + """ + return _apply_tensor_op(F.logical_and, x1, x2, dtype=dtype) + + +def logical_xor(x1, x2, dtype=None): + """ + Computes the truth value of `x1` XOR `x2`, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, + and `extobj` are not supported. + + Args: + x1 (Tensor): Input tensor. + x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. + Boolean result of the logical AND operation applied to the elements of `x1` and `x2`; + the shape is determined by broadcasting. This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([True, False]) + >>> x2 = np.array([False, False]) + >>> output = np.logical_xor(x1, x2) + >>> print(output) + [True False] + """ + _check_input_tensor(x1) + _check_input_tensor(x2) + y1 = F.logical_or(x1, x2) + y2 = F.logical_or(F.logical_not(x1), F.logical_not(x2)) + return _apply_tensor_op(F.logical_and, y1, y2, dtype=dtype) diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 17d30d7956..158e24d1e7 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -27,20 +27,21 @@ from .dtypes import nan, pi from .array_creations import asarray_const, ones, zeros, empty, full, full_like from .array_ops import where as where_ -from .array_ops import ravel, expand_dims +from .array_ops import ravel, expand_dims, moveaxis, concatenate from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \ - _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range -from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ - _check_input_tensor + _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \ + _check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \ + _make_tensor, _promote_for_trigonometric, _raise_runtime_error, _max +from .utils import _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ + _check_input_tensor, _to_tensor, _isnan ZERO_TENSOR = asarray_const(0) -_mean_default = P.ReduceMean() _mean_keepdims = P.ReduceMean(True) _matmul = P.MatMul(False, False) _matmul_T = P.MatMul(False, True) @@ -51,31 +52,20 @@ _reduce_min_keepdims = P.ReduceMin(True) _reduce_max_default = P.ReduceMax() _reduce_max_keepdims = P.ReduceMax(True) _cumsum_default = P.CumSum() +_concat = P.Concat(-1) -def absolute(x, out=None, where=True, dtype=None): +def absolute(x, dtype=None): """ Calculates the absolute value element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. 
`out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Currently the backend kernel only supports float calculation, if the input is not a `float`, then it will be casted to :class:`mstype.float32` and casted back. Args: x (Tensor): Tensor to be used for calculation. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -98,8 +88,8 @@ def absolute(x, out=None, where=True, dtype=None): original_dtype = x.dtype if not _check_is_float(original_dtype) and dtype is None: x = x.astype(mstype.float32) - return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype).astype(original_dtype) - return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.absolute, x, dtype=dtype).astype(original_dtype) + return _apply_tensor_op(F.absolute, x, dtype=dtype) def count_nonzero(x, axis=None, keepdims=False): @@ -139,7 +129,7 @@ def count_nonzero(x, axis=None, keepdims=False): return C.count_nonzero(x=x, axis=axis, keep_dims=keepdims) -def clip(x, xmin, xmax, out=None, where=True, dtype=None): +def clip(x, xmin, xmax, dtype=None): """ Clips (limits) the values in an array. @@ -155,15 +145,6 @@ def clip(x, xmin, xmax, out=None, where=True, dtype=None): on upper interval edge. Not more than one of `xmin` and `xmax` may be None. If `xmin` or `xmax` are tensors, then the three tensors will be broadcasted to match their shapes. - out (Tensor or None): optional, default to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -184,27 +165,18 @@ def clip(x, xmin, xmax, out=None, where=True, dtype=None): if xmin is None and xmax is None: _raise_value_error("One of max or min must be given.") if xmin is not None: - x = maximum(x, xmin, out=out, where=where, dtype=dtype) + x = maximum(x, xmin, dtype=dtype) if xmax is not None: - x = minimum(x, xmax, out=out, where=where, dtype=dtype) + x = minimum(x, xmax, dtype=dtype) return x -def deg2rad(x, out=None, where=True, dtype=None): +def deg2rad(x, dtype=None): """ Converts angles from degrees to radians. Args: x (Tensor): Angles in degrees. - out (Tensor or None, optional): defaults to None. 
- where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -229,24 +201,15 @@ def deg2rad(x, out=None, where=True, dtype=None): def convert(a): return a * pi / 180.0 - return _apply_tensor_op(convert, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(convert, x, dtype=dtype) -def rad2deg(x, out=None, where=True, dtype=None): +def rad2deg(x, dtype=None): """ Converts angles from radians to degrees. Args: x (Tensor): Angles in radians. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -271,32 +234,20 @@ def rad2deg(x, out=None, where=True, dtype=None): def convert(a): return a * 180.0 / pi - return _apply_tensor_op(convert, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(convert, x, dtype=dtype) -def add(x1, x2, out=None, where=True, dtype=None): +def add(x1, x2, dtype=None): """ Adds arguments element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): input to be added. x2 (Tensor): input to be added. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
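# --- Illustrative usage sketch (not part of the patch) ---------------------
# With this change the element-wise ufunc wrappers keep only `dtype` as a
# keyword argument; the numpy-style `out`/`where` arguments are removed from
# the signatures, so callers simply drop them.
import mindspore.numpy as np

x1 = np.full((3, 2), [1, 2]).astype('float32')
x2 = np.full((3, 2), [3, 4]).astype('float32')
res = np.add(x1, x2, dtype=np.float32)     # dtype still overrides the output type
angles = np.rad2deg(np.deg2rad(res))       # round trip: a * pi / 180, then a * 180 / pi
print(res)
print(angles)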
@@ -324,33 +275,21 @@ def add(x1, x2, out=None, where=True, dtype=None): # so we use tensor_sub as a substitute solution if _get_device() == 'CPU': _check_input_tensor(x1, x2) - return subtract(x1, F.neg_tensor(x2), out=out, where=where, dtype=dtype) - return _apply_tensor_op(F.tensor_add, x1, x2, out=out, where=where, dtype=dtype) + return subtract(x1, F.neg_tensor(x2), dtype=dtype) + return _apply_tensor_op(F.tensor_add, x1, x2, dtype=dtype) -def subtract(x1, x2, out=None, where=True, dtype=None): +def subtract(x1, x2, dtype=None): """ Subtracts arguments, element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): the input to be subtracted from. x2 (Tensor): the input to be subtracted by. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -374,32 +313,20 @@ def subtract(x1, x2, out=None, where=True, dtype=None): [-2, -2], [-2, -2]] """ - return _apply_tensor_op(F.tensor_sub, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_sub, x1, x2, dtype=dtype) -def multiply(x1, x2, out=None, where=True, dtype=None): +def multiply(x1, x2, dtype=None): """ Multiplies arguments element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): input tensor to be multiplied. x2 (Tensor): input tensor to be multiplied. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
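# --- Illustrative sketch of the CPU fallback used above (not part of the patch) ---
# On CPU, `add` is routed through `subtract(x1, F.neg_tensor(x2))`; the two
# forms are mathematically identical, as a quick check with the public API shows.
import mindspore.numpy as np

x1 = np.array([1.0, 2.0, 3.0])
x2 = np.array([4.0, 5.0, 6.0])
print(np.add(x1, x2))                        # x1 + x2
print(np.subtract(x1, np.negative(x2)))      # x1 - (-x2), same result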
@@ -430,10 +357,10 @@ def multiply(x1, x2, out=None, where=True, dtype=None): shape_out = _infer_out_shape(F.shape(x1), F.shape(x2)) x1 = _broadcast_to_shape(x1, shape_out) x2 = _broadcast_to_shape(x2, shape_out) - return _apply_tensor_op(F.tensor_mul, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_mul, x1, x2, dtype=dtype) -def divide(x1, x2, out=None, where=True, dtype=None): +def divide(x1, x2, dtype=None): """ Returns a true division of the inputs, element-wise. @@ -441,24 +368,12 @@ def divide(x1, x2, out=None, where=True, dtype=None): division. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): the divident. x2 (Tensor): the divisor. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -484,10 +399,10 @@ def divide(x1, x2, out=None, where=True, dtype=None): if not _check_is_float(F.dtype(x1)) and not _check_is_float(F.dtype(x2)): x1 = F.cast(x1, mstype.float32) x2 = F.cast(x2, mstype.float32) - return _apply_tensor_op(F.tensor_div, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_div, x1, x2, dtype=dtype) -def true_divide(x1, x2, out=None, where=True, dtype=None): +def true_divide(x1, x2, dtype=None): """ Returns a true division of the inputs, element-wise. @@ -495,23 +410,12 @@ def true_divide(x1, x2, out=None, where=True, dtype=None): division. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): the divident. x2 (Tensor): the divisor. - out (Tensor or None, optional) - where (Tensor, optional): - This condition is broadcast over the input. At locations where the - condition is True, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default out=None, - locations within it where the condition is False will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
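# --- Illustrative usage sketch (not part of the patch) ---------------------
# `divide` always performs true division: integer inputs are cast to float32
# before dividing, and `true_divide` gives the same result (it is implemented
# in terms of `divide`).
import mindspore.numpy as np

num = np.array([1, 2, 3])
den = np.array([3, 3, 3])
print(np.divide(num, den))        # float output, roughly [0.333, 0.667, 1.0]
print(np.true_divide(num, den))   # same result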
@@ -534,35 +438,23 @@ def true_divide(x1, x2, out=None, where=True, dtype=None): [0.33333333, 0.5], [0.33333333, 0.5]] """ - return divide(x1, x2, out=out, where=where, dtype=dtype) + return divide(x1, x2, dtype=dtype) -def power(x1, x2, out=None, where=True, dtype=None): +def power(x1, x2, dtype=None): """ First array elements raised to powers from second array, element-wise. Raises each base in `x1` to the positionally-corresponding power in `x2`. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16, and np.float32. Args: x1 (Tensor): the bases. x2 (Tensor): the exponents. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -586,10 +478,10 @@ def power(x1, x2, out=None, where=True, dtype=None): [ 1, 16], [ 1, 16]] """ - return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_pow, x1, x2, dtype=dtype) -def float_power(x1, x2, out=None, where=True, dtype=None): +def float_power(x1, x2, dtype=None): """ First array elements raised to powers from second array, element-wise. @@ -601,25 +493,13 @@ def float_power(x1, x2, out=None, where=True, dtype=None): and seldom overflow for positive powers. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Integers and floats are promoted to float32 instead of float64. Args: x1 (Tensor): the bases. x2 (Tensor): the exponenets. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
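# --- Illustrative usage sketch (not part of the patch) ---------------------
# `power` raises each base to the positionally-corresponding exponent;
# `float_power` does the same but promotes inputs to float32 (instead of
# numpy's float64).
import mindspore.numpy as np

x1 = np.full((3, 2), [1, 2]).astype('float32')
x2 = np.full((3, 2), [3, 4]).astype('float32')
print(np.power(x1, x2))         # [[1, 16], [1, 16], [1, 16]]
print(np.float_power(x1, x2))   # same values, as float32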
@@ -646,21 +526,18 @@ def float_power(x1, x2, out=None, where=True, dtype=None): if not _check_same_type(F.dtype(x2), mstype.float32): x2 = F.cast(x2, mstype.float32) - return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_pow, x1, x2, dtype=dtype) -def minimum(x1, x2, out=None, where=True, dtype=None): +def minimum(x1, x2, dtype=None): """ Element-wise minimum of tensor elements. Compares two tensors and returns a new tensor containing the element-wise minima. - Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Unlike numpy, when one of the elements is a NaN, the second element is always returned regardless of whether the second element is a NaN, instead of returning NaN. @@ -668,15 +545,6 @@ def minimum(x1, x2, out=None, where=True, dtype=None): Args: x1 (Tensor): first input tensor to be compared. x2 (Tensor): second input tensor to be compared. - out (Tensor or None, optional), default is None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
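# --- Illustrative sketch of the NaN behaviour noted above (not part of the patch) ---
# Unlike original numpy, when an element of `x1` is NaN the corresponding
# element of `x2` is returned instead of NaN.
import mindspore.numpy as np

x1 = np.array([float('nan'), 1.0, 4.0])
x2 = np.array([0.0, 2.0, 3.0])
print(np.minimum(x1, x2))   # [0.0, 1.0, 3.0] -- no NaN propagated from x1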
@@ -712,12 +580,12 @@ def minimum(x1, x2, out=None, where=True, dtype=None): # comparisons with 2 scalars if x1.ndim == 0 and x2.ndim == 0: x1 = expand_dims(x1, 0) - return _apply_tensor_op(F.minimum, x1, x2, out=out, where=where, dtype=dtype).squeeze() + return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype).squeeze() if x1.ndim == 0: dtype = x2.dtype elif x2.ndim == 0: dtype = x1.dtype - return _apply_tensor_op(F.minimum, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.minimum, x1, x2, dtype=dtype) def mean(a, axis=None, keepdims=False, dtype=None): @@ -763,34 +631,7 @@ def mean(a, axis=None, keepdims=False, dtype=None): >>> print(output) 2.5 """ - - axis = _check_axis_valid(axis, F.rank(a)) - shape_a = F.shape(a) - if dtype is None: - dtype = F.dtype(a) - - if _is_shape_empty(shape_a): - if keepdims: - shape_out = _shape_reduced_keepdims(shape_a, axis) - else: - shape_out = _shape_reduced(shape_a, axis) - if _is_shape_empty(shape_out): - return empty(F.dtype(a), shape_out) - return full(shape_out, nan, dtype) - - if _is_scalar(shape_a): - if keepdims: - return a - shape_out = _shape_reduced(shape_a, axis) - return F.reshape(a, shape_out) - - if keepdims: - res = _mean_keepdims(a, axis) - else: - res = _mean_default(a, axis) - if not _check_same_type(dtype, F.dtype(res)): - res = F.cast(res, dtype) - return res + return _reduce(a, P.ReduceMean(keepdims), axis=axis, keepdims=keepdims, dtype=dtype) def inner(a, b): @@ -1132,7 +973,7 @@ def var(x, axis=None, ddof=0, keepdims=False): return F.tensor_pow(x_std, 2) -def ptp(x, axis=None, out=None, keepdims=False): +def ptp(x, axis=None, keepdims=False): """ Range of values (maximum - minimum) along an axis. The name of the function comes from the acronym for ‘peak to peak’. @@ -1213,9 +1054,7 @@ def average(x, axis=None, weights=None, returned=False): Tensor(shape=[2], dtype=Float32, value= [ 4.00000000e+00, 6.00000000e+00])) """ _check_input_tensor(x) - if axis is None: - axis = () - else: + if axis is not None: _check_axis_type(axis, True, True, False) axis = _canonicalize_axis(axis, x.ndim) @@ -1227,9 +1066,9 @@ def average(x, axis=None, weights=None, returned=False): sum_of_weights = full((), x.size, F.dtype(x)) else: fill_value = 1 - if isinstance(axis, int) or isinstance(axis, tuple) and F.tuple_len(axis) == 1: - fill_value = x.shape[axis] - elif axis is None or axis == (): + if isinstance(axis, int) or (isinstance(axis, tuple) and F.tuple_len(axis) == 1): + fill_value = x.shape[axis] if isinstance(axis, int) else x.shape[axis[0]] + elif axis is None: for sh in x.shape: fill_value *= sh else: @@ -1258,6 +1097,7 @@ def average(x, axis=None, weights=None, returned=False): def comput_avg(x, axis, weights): """Computes average value of input x with given parameters.""" + axis = () if axis is None else axis x_mul = F.tensor_mul(x, weights) x_sum = _reduce_sum_default(x_mul, axis) sum_of_weights = _reduce_sum_default(weights, axis) @@ -1308,29 +1148,17 @@ def matmul(x1, x2, dtype=None): return C.matmul(x1, x2, dtype=dtype) -def square(x, out=None, where=True, dtype=None): +def square(x, dtype=None): """ Returns the element-wise square of the input. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. 
`out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16 and np.float32. Args: x (Tensor): Input data. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1351,32 +1179,20 @@ def square(x, out=None, where=True, dtype=None): [[ 0. 1. 4.] [ 9. 16. 25.]] """ - return _apply_tensor_op(F.square, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.square, x, dtype=dtype) -def sqrt(x, out=None, where=True, dtype=None): +def sqrt(x, dtype=None): """ Returns the non-negative square-root of an array, element-wise. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16 and np.float32. Args: x (Tensor): The values whose square-roots are required. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1400,17 +1216,17 @@ def sqrt(x, out=None, where=True, dtype=None): [[ 0. 1. 2.] [ 3. 4. 5.]] """ - return _apply_tensor_op(F.sqrt, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.sqrt, x, dtype=dtype) -def reciprocal(x, out=None, where=True, dtype=None): +def reciprocal(x, dtype=None): """ Returns the reciprocal of the argument, element-wise. Calculates ``1/x``. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. When `where` is provided, `out` must have a tensor value. `out` is not supported for storing the result, however it can be used in combination with `where` to set @@ -1420,15 +1236,6 @@ def reciprocal(x, out=None, where=True, dtype=None): x (Tensor): Input array. For integer arguments with absolute value larger than 1 the result is always zero because of the way Python handles integer division. For integer zero the result is an overflow. 
- out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1449,10 +1256,10 @@ def reciprocal(x, out=None, where=True, dtype=None): [[1. 0.5 0.33333334] [0.25 0.2 0.16666667]] """ - return _apply_tensor_op(lambda x: F.tensor_div(1, x), x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(lambda x: F.tensor_div(1, x), x, dtype=dtype) -def log(x, out=None, where=True, dtype=None): +def log(x, dtype=None): """ Returns the natural logarithm, element-wise. @@ -1460,11 +1267,8 @@ def log(x, out=None, where=True, dtype=None): ``log(exp(x)) = x``. The natural logarithm is logarithm in base e. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. @@ -1472,15 +1276,6 @@ def log(x, out=None, where=True, dtype=None): x (Tensor): Input array. For integer arguments with absolute value larger than 1 the result is always zero because of the way Python handles integer division. For integer zero the result is an overflow. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1501,21 +1296,18 @@ def log(x, out=None, where=True, dtype=None): >>> print(output) [0.69314575 1.09861 1.3862929 ] """ - return _apply_tensor_op(F.log, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.log, x, dtype=dtype) -def maximum(x1, x2, out=None, where=True, dtype=None): +def maximum(x1, x2, dtype=None): """ Returns the element-wise maximum of array elements. Compares two arrays and returns a new array containing the element-wise maxima. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. 
`out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Unlike numpy, when one of the elements is a NaN, the second element is always returned regardless of whether the second element is a NaN, instead of returning NaN. @@ -1525,15 +1317,6 @@ def maximum(x1, x2, out=None, where=True, dtype=None): x2 (Tensor): The array holding the elements to be compared. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1566,39 +1349,27 @@ def maximum(x1, x2, out=None, where=True, dtype=None): # F.maximum does not support when both operands are scalar if x1.ndim == 0 and x2.ndim == 0: x1 = expand_dims(x1, 0) - return _apply_tensor_op(F.maximum, x1, x2, out=out, where=where, dtype=dtype).squeeze() + return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype).squeeze() if x1.ndim == 0: dtype = x2.dtype elif x2.ndim == 0: dtype = x1.dtype - return _apply_tensor_op(F.maximum, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.maximum, x1, x2, dtype=dtype) -def heaviside(x1, x2, out=None, where=True, dtype=None): +def heaviside(x1, x2, dtype=None): """ Computes the Heaviside step function. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): Input values. x2 (Tensor): The value of the function when `x1` is 0. If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. 
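# --- Illustrative usage sketch (not part of the patch) ---------------------
# Heaviside step function: 0 where x1 < 0, 1 where x1 > 0, and the value of
# `x2` where x1 == 0.
import mindspore.numpy as np

x1 = np.array([-1.5, 0.0, 2.0])
x2 = np.full((3,), 0.5)
print(np.heaviside(x1, x2))   # [0.0, 0.5, 1.0]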
@@ -1642,7 +1413,7 @@ def heaviside(x1, x2, out=None, where=True, dtype=None): x2 = F.select(x1 > 0, ones(shape_out, dtype_out), x2) return x2 - return _apply_tensor_op(_heaviside, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_heaviside, x1, x2, dtype=dtype) def amax(a, axis=None, keepdims=False, initial=None, where=True): @@ -1697,7 +1468,7 @@ def amax(a, axis=None, keepdims=False, initial=None, where=True): >>> print(output) [-1. 3.] """ - return _reduce(a, P.ReduceMax(keepdims), F.maximum, axis=axis, keepdims=keepdims, + return _reduce(a, P.ReduceMax(keepdims), cmp_fn=F.maximum, axis=axis, keepdims=keepdims, initial=initial, where=where) @@ -1753,11 +1524,11 @@ def amin(a, axis=None, keepdims=False, initial=None, where=True): >>> print(output) [10. 1.] """ - return _reduce(a, P.ReduceMin(keepdims), F.minimum, axis=axis, keepdims=keepdims, + return _reduce(a, P.ReduceMin(keepdims), cmp_fn=F.minimum, axis=axis, keepdims=keepdims, initial=initial, where=where) -def hypot(x1, x2, out=None, where=True, dtype=None): +def hypot(x1, x2, dtype=None): """ Given the “legs” of a right triangle, returns its hypotenuse. @@ -1766,11 +1537,8 @@ def hypot(x1, x2, out=None, where=True, dtype=None): with each element of the other argument. (See Examples) Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16 and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. @@ -1779,15 +1547,6 @@ def hypot(x1, x2, out=None, where=True, dtype=None): x2 (Tensor): Leg of the triangle(s). If ``x1.shape != x2.shape``, they must be broadcastable to a common shape (which becomes the shape of the output). - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1823,35 +1582,23 @@ def hypot(x1, x2, out=None, where=True, dtype=None): return F.sqrt(F.tensor_sub(F.square(x1), F.neg_tensor(F.square(x2)))) return F.sqrt(F.tensor_add(F.square(x1), F.square(x2))) - return _apply_tensor_op(_hypot, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_hypot, x1, x2, dtype=dtype) -def floor(x, out=None, where=True, dtype=None): +def floor(x, dtype=None): """ Returns the floor of the input, element-wise. The floor of the scalar `x` is the largest integer `i`, such that ``i <= x``. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. 
`out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. On GPU, the supported dtypes are np.float16 and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. Args: x (Tensor): input data. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1871,34 +1618,22 @@ def floor(x, out=None, where=True, dtype=None): >>> print(output) [-2. -2. -1. 0. 1. 1. 2.] """ - return _apply_tensor_op(F.floor, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.floor, x, dtype=dtype) -def floor_divide(x1, x2, out=None, where=True, dtype=None): +def floor_divide(x1, x2, dtype=None): """ Returns the largest integer smaller or equal to the division of the inputs. It is equivalent to the Python // operator and pairs with the Python % (remainder), function so that ``a = a % b + b * (a // b)`` up to roundoff. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): Input array. x2 (Tensor): Input array. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1917,7 +1652,7 @@ def floor_divide(x1, x2, out=None, where=True, dtype=None): >>> print(output) [0. 0. 1. 1.] """ - return _apply_tensor_op(F.tensor_floordiv, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_floordiv, x1, x2, dtype=dtype) def _remainder(x1, x2, C_style=False): @@ -1944,7 +1679,7 @@ def _remainder(x1, x2, C_style=False): return res -def remainder(x1, x2, out=None, where=True, dtype=None): +def remainder(x1, x2, dtype=None): """ Returns element-wise remainder of division. @@ -1953,24 +1688,12 @@ def remainder(x1, x2, out=None, where=True, dtype=None): as the divisor `x2`. The MATLAB function equivalent to np.remainder is mod. 
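# --- Illustrative sketch of the identity noted above (not part of the patch) ---
# `floor_divide` pairs with `remainder` so that, up to roundoff,
#     x1 == remainder(x1, x2) + x2 * floor_divide(x1, x2)
# and the remainder takes the sign of the divisor `x2`.
import mindspore.numpy as np

x1 = np.array([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0])
x2 = np.full((6,), 2.0)
r = np.remainder(x1, x2)                  # non-negative here, like Python's %
q = np.floor_divide(x1, x2)
print(r)
print(np.add(np.multiply(q, x2), r))      # reconstructs x1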
Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor): input array. x2 (Tensor): input array. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -1993,7 +1716,7 @@ def remainder(x1, x2, out=None, where=True, dtype=None): >>> print(output) [0 1 2 3 4 0 1] """ - return _apply_tensor_op(_remainder, x1, x2, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_remainder, x1, x2, dtype=dtype) def fix(x): @@ -2034,7 +1757,7 @@ def fix(x): return F.select(is_neg, ceiled, floored) -def fmod(x1, x2, out=None, where=True, dtype=None): +def fmod(x1, x2, dtype=None): """ Returns the element-wise remainder of division. @@ -2043,24 +1766,12 @@ def fmod(x1, x2, out=None, where=True, dtype=None): function and should not be confused with the Python modulus operator ``x1 % x2``. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x1 (Tensor) x2 (Tensor): input arrays. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -2080,11 +1791,10 @@ def fmod(x1, x2, out=None, where=True, dtype=None): >>> print(output) [-1 0 -1 1 0 1] """ - return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, C_style=True), x1, x2, - out=out, where=where, dtype=dtype) + return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, C_style=True), x1, x2, dtype=dtype) -def trunc(x, out=None, where=True, dtype=None): +def trunc(x, dtype=None): """ Returns the truncated value of the input, element-wise. @@ -2092,23 +1802,11 @@ def trunc(x, out=None, where=True, dtype=None): than `x` is. 
In short, the fractional part of the signed number `x` is discarded. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. Args: x (Tensor): input data. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -2128,15 +1826,15 @@ def trunc(x, out=None, where=True, dtype=None): >>> print(output) [-1. -1. -0. 0. 1. 1. 2.] """ - return _apply_tensor_op(fix, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(fix, x, dtype=dtype) -def exp(x, out=None, where=True, dtype=None): +def exp(x, dtype=None): """ Calculates the exponential of all elements in the input array. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. When `where` is provided, `out` must have a tensor value. `out` is not supported for storing the result, however it can be used in combination with `where` to set @@ -2146,15 +1844,6 @@ def exp(x, out=None, where=True, dtype=None): Args: x (Tensor): input data. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -2174,33 +1863,21 @@ def exp(x, out=None, where=True, dtype=None): >>> print(output) [ 1. 2.718282 7.3890557 20.085537 54.598145 ] """ - return _apply_tensor_op(F.tensor_exp, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_exp, x, dtype=dtype) -def expm1(x, out=None, where=True, dtype=None): +def expm1(x, dtype=None): """ Calculates ``exp(x) - 1`` for all elements in the array. Note: - Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - When `where` is provided, `out` must have a tensor value. `out` is not supported - for storing the result, however it can be used in combination with `where` to set - the value at indices for which `where` is set to False. 
On GPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are np.float16, and np.float32. Args: x (Tensor): input data. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. @@ -2220,234 +1897,1960 @@ def expm1(x, out=None, where=True, dtype=None): >>> print(output) [ 0. 1.7182819 6.389056 19.085537 53.59815 ] """ - return _apply_tensor_op(F.tensor_expm1, x, out=out, where=where, dtype=dtype) + return _apply_tensor_op(F.tensor_expm1, x, dtype=dtype) -@constexpr -def _real_axes(ndim_orig, ndim_out, axes_orig): - """Returns the real axes to be reduced after performing broadcast""" - diff = ndim_out - ndim_orig - axes = F.make_range(diff) - axes_orig = map(functools.partial(operator.add, diff), axes_orig) - return axes + tuple(axes_orig) +def divmod_(x1, x2, dtype=None): + """ + Returns element-wise quotient and remainder simultaneously. + Args: + x1(Union[Tensor]): Dividend tensor. + x2(Union[Tensor, int, float, bool]): Divisor. If ``x1.shape != x2.shape``, + they must be broadcastable to a common shape. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. -@constexpr -def _shape_reduced_keepdims(shape, axes): + Returns: + Element-wise quotient and remainder from floor division, in format of (quotient, remainder) + + Raises: + TypeError: if `x1` and `x2` are not Tensor or scalar. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.array([1, 2, 3, 4, 5]) + >>> print(np.divmod(a, 1.5)) + (Tensor(shape=[5], dtype=Float32, + value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00, 2.00000000e+00, 3.00000000e+00]), + Tensor(shape=[5], dtype=Float32, + value= [ 1.00000000e+00, 5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 5.00000000e-01])) """ - Reduces dimensions corresponding to argument axes while - keeping the number of dimensions unchanged. + q = F.tensor_floordiv(x1, x2) + r = remainder(x1, x2) + if dtype is not None: + q = q.astype(dtype) + r = r.astype(dtype) + return (q, r) + + +def diff(a, n=1, axis=-1, prepend=None, append=None): """ - ndim_out = F.tuple_len(shape) - shape_out = [1]*ndim_out - for i in range(ndim_out): - if not i in axes: - shape_out[i] = shape[i] - return tuple(shape_out) + Calculates the n-th discrete difference along the given axis. + The first difference is given by :math:`out[i] = a[i+1] - a[i]` along the given axis, + higher differences are calculated by using `diff` iteratively. -@constexpr -def _shape_reduced(shape, axes): - """Removes dimensions corresponding to argument axes""" - ndim_orig = F.tuple_len(shape) - ndim_out = ndim_orig - F.tuple_len(axes) - shape_out = [0]*ndim_out - idx_out = 0 - for i in range(ndim_orig): - if not i in axes: - shape_out[idx_out] = shape[i] - idx_out += 1 - return tuple(shape_out) + Args: + a (Tensor): Input tensor. 
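# --- Illustrative usage sketch (not part of the patch) ---------------------
# `np.divmod` returns the floor-division quotient and remainder in a single
# call; quotient * divisor + remainder reconstructs the dividend.
import mindspore.numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
b = np.full((5,), 1.5)
q, r = np.divmod(a, b)
print(q)                               # [0., 1., 2., 2., 3.]
print(np.add(np.multiply(q, b), r))    # recovers `a`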
+ n (int, optional): The number of times values are differenced. If zero, + the input is returned as-is. + axis (int, optional): The axis along which the difference is taken, default + is the last axis. + prepend/append (Tensor, optional): Values to prepend or append to a along + `axis` prior to performing the difference. Scalar values are expanded to + arrays with length 1 in the direction of `axis` and the shape of the input + array in along all other axes. Otherwise the dimension and shape must + match `a` except along axis. + + Returns: + The n-th differences. The shape of the output is the same as a except along + `axis` where the dimension is smaller by `n`. The type of the output is the same + as the type of the difference between any two elements of `a`. This is the same + as the type of `a` in most cases. + + Raises: + TypeError: If inputs have types not specified above. + ValueError: If ``n < 0``. + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` -def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True): - """Applies comparison based on cmp_fn and reduction based on reduce_fn""" + Examples: + >>> import mindspore.numpy as np + >>> arr = np.array([1, 3, -1, 0, 4]) + >>> print(np.diff(arr, n=2)) + [-6 5 3] + """ + # This implementation is inspired by jax.numpy _check_input_tensor(a) + axis = _canonicalize_axis(axis, a.ndim) + if not isinstance(n, int): + _raise_type_error("Input n should be int, but got ", n) + if n < 0: + _raise_value_error("Input n must > 0.") + if n == 0: + return a - shape = F.shape(a) - ndim = F.rank(a) - dtype = F.dtype(a) - axes = _check_axis_valid(axis, ndim) - if initial is not None: - if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or - not isinstance(initial, (int, float, bool, Tensor))): - _raise_type_error('initial should be scalar') + combined = () + if prepend is not None: + if isinstance(prepend, (int, float, bool)): + prepend = asarray_const(prepend) + prepend_shape = a.shape + prepend_shape = _tuple_setitem(prepend_shape, axis, 1) + prepend = _broadcast_to_shape(prepend, prepend_shape) + elif not isinstance(prepend, Tensor): + _raise_type_error("prepend must be scalar or Tensor, but got ", prepend) + combined += (prepend,) + + combined += (a,) + + if append is not None: + if isinstance(append, (int, float, bool)): + append = asarray_const(append) + append_shape = a.shape + append_shape = _tuple_setitem(append_shape, axis, 1) + append = _broadcast_to_shape(append, append_shape) + elif not isinstance(append, Tensor): + _raise_type_error("append must be scalar or Tensor, but got ", append) + combined += (append,) + + if combined: + a = concatenate(combined, axis) + + # if n > maximum length allowed, returns empty tensor, with shape matched with + # the original tensor + if n > a.shape[axis]: + empty_shape = a.shape + empty_shape = _tuple_setitem(empty_shape, axis, 0) + return empty(empty_shape, a.dtype) + + original_dtype = a.dtype + # will change once F.tensor_slice supports types other than float32 + if not _check_is_float(original_dtype): + a = a.astype(mstype.float32) + a = moveaxis(a, axis, -1) + for _ in F.make_range(n): + slice_start = _list_comprehensions(F.rank(a) - 1, 0, True) + slice_size = F.shape(a)[:-1] + (F.shape(a)[-1] - 1,) + minuend = F.tensor_slice(a, slice_start + (1,), slice_size) + subtrahend = F.tensor_slice(a, slice_start + (0,), slice_size) + a = F.tensor_sub(minuend, subtrahend) + if not _check_is_float(original_dtype): + a = a.astype(original_dtype) + return moveaxis(a, -1, axis) + 
+ +def ediff1d(ary, to_end=None, to_begin=None): + """ + The differences between consecutive elements of a tensor. - if _is_shape_empty(shape): - if not axes: - return a - if keepdims: - shape_out = _shape_reduced_keepdims(shape, axes) + Args: + ary (Tensor): If necessary, will be flattened before the differences are taken. + to_end (Tensor or scalar, optional): Number(s) to append at the end of the + returned differences. + to_begin (Tensor or scalar, optional): Number(s) to prepend at the beginning + of the returned differences. + + Returns: + The differences. + + Raises: + TypeError: If inputs have types not specified above. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> arr = np.array([1, 3, -1, 0, 4]) + >>> print(np.ediff1d(arr)) + [ 2 -4 1 4] + """ + _check_input_tensor(ary) + combined = () + + if to_begin is not None: + if isinstance(to_begin, Tensor): + to_begin = to_begin.ravel() else: - shape_out = _shape_reduced(shape, axes) - if _is_shape_empty(shape_out): - return empty(F.dtype(a), shape_out) - if initial is None: - return _raise_value_error('initial value must be provided for zero-size arrays') - return full(shape_out, initial, dtype) + to_begin = _to_tensor(to_begin).ravel() + to_begin = to_begin.astype(ary.dtype) + combined += (to_begin,) - if initial is not None: - initial = full(shape, initial, dtype) - a = cmp_fn(a, initial) - if not axes: - return a - if isinstance(where, Tensor): - if initial is None: - return _raise_value_error('initial value must be provided for where masks') - ndim_orig = F.rank(a) - a = where_(where, a, initial) - axes = _real_axes(ndim_orig, F.rank(a), axes) + combined += (diff(ary.ravel()),) + + if to_end is not None: + if isinstance(to_end, Tensor): + to_end = to_end.ravel() + else: + to_end = _to_tensor(to_end).ravel() + to_end = to_end.astype(ary.dtype) + combined += (to_end,) - return reduce_fn(a, axes) + return P.Concat(0)(combined) -def positive(a, out=None, where=True, dtype=None): +def trapz(y, x=None, dx=1.0, axis=-1): """ - Numerical positive, element-wise. + Integrates along the given axis using the composite trapezoidal rule. + + Integrates `y` (x) along given axis. + + Args: + y (Tensor): Input array to integrate. + x (Union[int, float, bool, list, tuple, Tensor], optional): The sample points + corresponding to the `y` values. If `x` is None, the sample points are + assumed to be evenly spaced `dx` apart. The default is None. + dx (scalar, optional): The spacing between sample points when `x` is None. The + default is 1. + axis (int, optional): The axis along which to integrate. + + Returns: + Tensor of float, definite integral as approximated by trapezoidal rule. + + Raises: + ValueError: If axis is out of range of ``[-y.ndim, y.ndim)``. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.arange(6).reshape(2, 3) + >>> output = np.trapz(a, x=[-2, 1, 2], axis=1) + >>> print(output) + [ 3. 15.] 
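+        >>> # Additional usage sketch (illustrative): a 1-D input with the default ``dx=1``.
+        >>> print(np.trapz([1., 2., 3.]))
+        4.0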
+ >>> output = np.trapz(a, dx=3, axis=0) + >>> print(output) + [ 4.5 7.5 10.5] + """ + y = _to_tensor(y) + ndim = F.rank(y) + _check_axis_in_range(axis, ndim) + axis = axis + ndim if axis < 0 else axis + y_start_axis_left = _list_comprehensions(axis, 0, True) + y_start_axis_right = _list_comprehensions(ndim - axis - 1, 0, True) + shape = F.shape(y) + y_slice_size = _tuple_setitem(shape, axis, shape[axis] - 1) + if x is not None: + x = _to_tensor(x) + dx = diff(x) + else: + dx = _to_tensor(dx) + dx = _expand(dx, ndim - axis, axis=-1) + dx = _broadcast_to_shape(dx, y_slice_size) + if not _check_is_float(F.dtype(y)): + # trapz returns float + y = F.cast(y, mstype.float32) + dx = F.cast(dx, F.dtype(y)) + + # product of dx and y with the last column removed + y_slice_left = F.tensor_slice(y, y_start_axis_left + (0,) + y_start_axis_right, y_slice_size) + prod_left = F.tensor_mul(y_slice_left, dx) + # product of dx and y with the first column removed + y_slice_right = F.tensor_slice(y, y_start_axis_left + (1,) + y_start_axis_right, y_slice_size) + prod_right = F.tensor_mul(y_slice_right, dx) + prod_sum = F.tensor_div(F.tensor_add(prod_left, prod_right), _to_tensor(2.0).astype(F.dtype(y))) + return F.reduce_sum(prod_sum, axis) + + +def _gcd(x1, x2): + """Calculates gcd without applying keyword arguments.""" + dtype = _promote(F.dtype(x1), F.dtype(x2)) + if _get_device() == 'CPU' and not _check_is_float(dtype): + # F.reduce_sum only supports float + x1 = F.cast(x1, mstype.float32) + x2 = F.cast(x2, mstype.float32) + x1 = F.absolute(x1) + x2 = F.absolute(x2) + cond_ge = F.tensor_ge(x1, x2) + a = where_(cond_ge, x1, x2) + b = where_(cond_ge, x2, x1) + b = where_(F.equal(b, ZERO_TENSOR), a, b) + r = _remainder(a, b) + while F.tensor_gt(F.reduce_sum(r), ZERO_TENSOR): + r = _remainder(a, b) + has_terminated = F.equal(r, ZERO_TENSOR) + a = where_(has_terminated, a, b) + b = where_(has_terminated, b, r) + if not _check_same_type(F.dtype(b), dtype): + b = F.cast(b, dtype) + return b + + +def gcd(x1, x2, dtype=None): + """ + Returns the greatest common divisor of ``|x1|`` and ``|x2|``. Note: - Numpy arguments casting, order, subok, signature, and extobj are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. Args: - a (Tensor): Input tensor. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. + x1 (Tensor): input data. + x2 (Tensor): input data. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: - Tensor. + Tensor or scalar, the greatest common divisor of the absolute value of the inputs. + This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> import mindspore.numpy as np - >>> a = np.asarray([1, -1]).astype('float32') - >>> output = np.positive(a) + >>> output = np.gcd(np.arange(6), np.array(20)) >>> print(output) - [1. -1.] 
+ [20 1 2 1 4 5] """ - _check_input_tensor(a) - neg_tensor = F.neg_tensor(a) - return _apply_tensor_op(F.neg_tensor, neg_tensor, out=out, where=where, dtype=dtype) + return _apply_tensor_op(_gcd, x1, x2, dtype=dtype) -def negative(a, out=None, where=True, dtype=None): +def lcm(x1, x2, dtype=None): """ - Numerical negative, element-wise. + Returns the lowest common multiple of ``|x1|`` and ``|x2|``. Note: - Numpy arguments `casting`, `order`, `subok`, `signature`, and `extobj` are + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. Args: - a (Tensor): Input tensor. - out (Tensor or None, optional): defaults to None. - where (Tensor or None, optional): For any non-default value of type other - than :class:`Tensor` or :class:`None`, the output retains its original value. - This condition is broadcasted over the input. At locations where the - condition is `True`, the out array will be set to the ufunc result. - Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default ``out=None``, - locations within it where the condition is `False` will remain - uninitialized. + x1 (Tensor): input data. + x2 (Tensor): input data. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: - Tensor. + Tensor or scalar, the lowest common multiple of the absolute value of the inputs. + This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> import mindspore.numpy as np - >>> a = np.asarray([1, -1]).astype('float32') - >>> output = np.negative(a) + >>> output = np.lcm(np.arange(6), np.array(20)) >>> print(output) - [-1. 1.] + [ 0 20 20 60 20 20] """ - _check_input_tensor(a) - return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype) + def _lcm(x1, x2): + """Calculates lcm without applying keyword arguments""" + common_divisor = _gcd(x1, x2) + q1 = F.tensor_div(x1, common_divisor) + q2 = F.tensor_div(x2, common_divisor) + res = F.tensor_mul(F.tensor_mul(q1, q2), common_divisor) + dtype = F.dtype(res) + if _get_device() == 'CPU' and not _check_is_float(dtype): + # F.absolute only supports float + res = F.cast(res, mstype.float32) + return F.absolute(res).astype(dtype) + return _apply_tensor_op(_lcm, x1, x2, dtype=dtype) -def cumsum(a, axis=None, dtype=None): + +def convolve(a, v, mode='full'): """ - Returns the cumulative sum of the elements along a given axis. + Returns the discrete, linear convolution of two one-dimensional sequences. + + Note: + If `v` is longer than `a`, the tensors are swapped before computation. Args: - a (Tensor): Input tensor. - axis (int, optional): Axis along which the cumulative sum is computed. The - default (None) is to compute the cumsum over the flattened array. - dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, - unless `a` has an integer dtype with a precision less than that of the - default platform integer. In that case, the default platform integer - is used. + a (Union[list, tuple, Tensor]): First one-dimensional input tensor. + v (Union[list, tuple, Tensor]): Second one-dimensional input tensor. + + mode (str, optional): By default, mode is `\'full\'`. This returns the + convolution at each point of overlap, with an output shape of :math:`(N+M-1,)`. 
+ At the end-points of the convolution, the signals do not overlap completely, + and boundary effects may be seen. + If `mode` is `\'same\'`, it returns output of length :math:`max(M, N)`. Boundary + effects are still visible. + If `mode` is `\'valid\'`, it returns output of length :math:`max(M, N) - min(M, N) + 1`. + The convolution product is only given for points where the signals overlap + completely. Values outside the signal boundary have no effect. Returns: - Tensor. + Tensor, discrete, linear convolution of a and v. Raises: - TypeError: If input arguments have types not specified above. - ValueError: If axis is out of range. + TypeError: if the inputs have types not specified above. + ValueError: if a and v are empty or have wrong dimensions Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: >>> import mindspore.numpy as np - >>> output = np.cumsum(np.ones((3,3)), axis=0) + >>> output = np.convolve([1., 2., 3., 4., 5.], [2., 3.], mode="valid") >>> print(output) - [[1. 1. 1.] - [2. 2. 2.] - [3. 3. 3.]] - """ - _check_input_tensor(a) - original_dtype = F.dtype(a) - # If original array is int, and has precision less then int32, convert to int32 - if _check_same_type(original_dtype, mstype.bool_) or \ - _check_same_type(original_dtype, mstype.int8) or \ - _check_same_type(original_dtype, mstype.int16): - original_dtype = mstype.int32 - a = a.astype(mstype.float32) - if axis is None: - a = a.ravel() - axis = 0 - _check_axis_in_range(axis, a.ndim) - if dtype is not None and not _check_same_type(original_dtype, dtype): - return _cumsum_default(a, axis).astype(dtype, copy=False) - return _cumsum_default(a, axis).astype(original_dtype, copy=False) + [ 3. 6. 9. 12.] + """ + if not isinstance(a, Tensor): + a = asarray_const(a) + if not isinstance(v, Tensor): + v = asarray_const(v) + if a.size == 0 or v.size == 0: + _raise_value_error("Inputs cannot be empty.") + a = _expand(a, 1) + v = _expand(v, 1) + final_dtype = _promote(a.dtype, v.dtype) + a = a.astype("float32") + v = v.astype("float32") + if a.ndim != 1 or v.ndim != 1: + _raise_value_error("a and v must be 1-D tensor.") + if a.size < v.size: + a, v = v, a + v = v[::-1] + if mode not in ('same', 'full', 'valid'): + _raise_value_error("mode must be one of ['full', 'same', 'valid']") + if v.size > 1: + if mode == 'same': + pad_left = _to_tensor(_list_comprehensions(v.size // 2, 0.0, True)) + pad_right = _to_tensor(_list_comprehensions(v.size - v.size // 2 - 1, 0.0, True)) + a = P.Concat(axis=0)((pad_left, a, pad_right)) + elif mode == 'full': + pad = _to_tensor(_list_comprehensions(v.size - 1, 0.0, True)) + a = P.Concat(axis=0)((pad, a, pad)) + a = a.reshape(1, 1, 1, a.size) + v = v.reshape(1, 1, 1, v.size) + _conv = P.Conv2D(out_channel=1, kernel_size=(1, v.size), pad_mode="valid") + return _conv(a, v).reshape(-1).astype(final_dtype) + + +def _handle_weights(weights, num_samples): + """Checks fweight and aweight in np.cov.""" + weights = asarray_const(weights) + if not _check_is_int(weights.dtype): + _raise_type_error("weights must be integer") + weights = weights.astype("float32") + if weights.ndim > 1: + _raise_runtime_error("cannot handle multidimensional weights") + if weights.shape[0] != num_samples: + _raise_runtime_error("incompatible numbers of samples and weights") + return absolute(weights) + + +def _handle_inputs(cov_input, rowvar): + """Checks input arrays for np.cov.""" + if not isinstance(cov_input, Tensor): + cov_input = asarray_const(cov_input) + if cov_input.ndim > 2: + _raise_value_error("input array has dimension 
more than 2.") + cov_input = cov_input.astype("float32") + cov_input = _expand(cov_input, 2) + if not rowvar and cov_input.shape[0] != 1: + cov_input = cov_input.T + return cov_input + + +def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights=None, dtype=None): + """ + Estimates a covariance matrix, given data and weights. + + Covariance indicates the level to which two variables vary together. If we examine + N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`, then the covariance matrix + element :math:`C_{ij}` is the covariance of :math:`x_i` and :math:`x_j`. The element + :math:`C_{ii}` is the variance of :math:`x_i`. + Note: + `fweights` and `aweights` must be all positive, in Numpy if negative values + are detected, a value error will be raised, in MindSpore we converts all values + to positive instead. -def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): - """Applies tensor operations based on fn""" - _check_input_tensor(*args) - res = fn(*args) - - # if out is set to a non-default value, return tensor will have the same - # dtype as out, which overrides the dtype passed into the keyword argument - if isinstance(out, Tensor): - dtype_out = F.dtype(out) - elif dtype is not None: - dtype_out = dtype + Args: + m (Union[Tensor, list, tuple]): A 1-D or 2-D tensor containing multiple variables + and observations. Each row of `m` represents a variable, and each column + represents a single observation of all those variables. Also see `rowvar` below. + y (Union[Tensor, list, tuple], optional): An additional set of variables + and observations. `y` has the same form as that of `m`. + rowvar(bool, optional): If `rowvar` is ``True`` (default), then each row represents + a variable, with observations in the columns. Otherwise, the relationship + is transposed: each column represents a variable, while the rows contain + observations. + bias (bool, optional): Default normalization (``False``) is by :math:`(N - 1)`, where + :math:`N` is the number of observations given (unbiased estimate). If bias is + ``True``, then normalization is by `N`. These values can be overridden by + using the keyword `ddof`. + ddof (int, optional): If not ``None``, the default value implied by `bias` is + overridden. Note that :math:`ddof=1` will return the unbiased estimate, even + if both fweights and aweights are specified, and :math:`ddof=0` will return + the simple average. See the notes for the details. The default value + is ``None``. + fweights (Union[Tensor, list, tuple], optional): 1-D tensor of integer + frequency weights; the number of times each observation vector should + be repeated. + aweights (Union[Tensor, list, tuple], optional): 1-D tensor of observation + vector weights. These relative weights are typically larger for observations + considered more important and smaller for observations considered less + important. If :math:`ddof=0` the tensor of weights can be used to assign probabilities + to observation vectors. + dtype (Union[:class:`mindspore.dtype`, str], optional): Data-type of the + result. By default, the return data-type will have mstype.float32 precision. + + Returns: + Tensor, the covariance matrix of the variables. + + Raises: + TypeError: if the inputs have types not specified above. + ValueError: if `m` and `y` have wrong dimensions. + RuntimeError: if `aweights` and `fweights` have dimensions > 2. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.cov([[2., 3., 4., 5.], [0., 2., 3., 4.], [7., 8., 9., 10.]]) + >>> print(output) + [[1.6666666 2.1666667 1.6666666] + [2.1666667 2.9166667 2.1666667] + [1.6666666 2.1666667 1.6666666]] + """ + # This implementation was inspired by original numpy implementation. + m = _handle_inputs(m, rowvar) + + if m.shape[0] == 0: + return empty((0, 0), dtype="float32") + + if y is not None: + y = _handle_inputs(y, rowvar) + m = concatenate((m, y), axis=0) + + if ddof is None: + if not bias: + ddof = 1 + else: + ddof = 0 + + # Handle fweights and aweights + w = _handle_weights(fweights, m.shape[1]) if fweights is not None else None + + if aweights is not None: + aweights = _handle_weights(aweights, m.shape[1]) + w = aweights if w is None else w * aweights + + avg = average(m, axis=1, weights=w) + + # Determine the normalization + if w is None: + fact = m.shape[1] - ddof else: - dtype_out = F.dtype(res) + w_sum = _reduce_sum_default(w, -1) + if ddof == 0: + fact = w_sum + elif aweights is None: + fact = w_sum - ddof + else: + fact = w_sum - ddof * F.reduce_sum(w * aweights) / w_sum + + m = m - F.expand_dims(avg, -1) + if w is None: + m_T = m.T + else: + m_T = (m * w).T + res = true_divide(dot(m, m_T), fact).squeeze() + if dtype is not None: + return res.astype(dtype) + return res + + +@constexpr +def _real_axes(ndim_orig, ndim_out, axes_orig): + """Returns the real axes to be reduced after performing broadcast""" + _diff = ndim_out - ndim_orig + axes = F.make_range(_diff) + axes_orig = map(functools.partial(operator.add, _diff), axes_orig) + return axes + tuple(axes_orig) + + +@constexpr +def _shape_reduced_keepdims(shape, axes): + """ + Reduces dimensions corresponding to argument axes while + keeping the number of dimensions unchanged. + """ + ndim_out = F.tuple_len(shape) + shape_out = [1]*ndim_out + for i in range(ndim_out): + if not i in axes: + shape_out[i] = shape[i] + return tuple(shape_out) + + +@constexpr +def _shape_reduced(shape, axes): + """Removes dimensions corresponding to argument axes""" + ndim_orig = F.tuple_len(shape) + ndim_out = ndim_orig - F.tuple_len(axes) + shape_out = [0]*ndim_out + idx_out = 0 + for i in range(ndim_orig): + if not i in axes: + shape_out[idx_out] = shape[i] + idx_out += 1 + return tuple(shape_out) + + +def _reduce(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None): + """ + Applies comparison based on cmp_fn and reduction based on reduce_fn. + If cmp_fn is None, only reduction is performed. 
+ """ + _check_input_tensor(a) + + shape = F.shape(a) + ndim = F.rank(a) + if dtype is None: + dtype = F.dtype(a) + axes = _check_axis_valid(axis, ndim) + + if _is_shape_empty(shape): + if not axes: + return a + if keepdims: + shape_out = _shape_reduced_keepdims(shape, axes) + else: + shape_out = _shape_reduced(shape, axes) + if _is_shape_empty(shape_out): + return empty(shape_out, dtype) + if initial is None: + if cmp_fn is None: + initial = nan + else: + return _raise_value_error('initial value must be provided for zero-size arrays') + return full(shape_out, initial, dtype) + + if initial is not None: + initial = full(shape, initial, dtype) + a = cmp_fn(a, initial) + if not axes: + return a.astype(dtype) + if isinstance(where, Tensor): + if initial is None: + return _raise_value_error('initial value must be provided for where masks') + ndim_orig = F.rank(a) + a = where_(where, a, initial) + axes = _real_axes(ndim_orig, F.rank(a), axes) + + return reduce_fn(a, axes).astype(dtype) + + +def _reduce_nansum(x, axis, keepdims=False): + """Computes reduce sum treating NaNs as zeros.""" + x = F.select(_isnan(x), zeros(F.shape(x), F.dtype(x)), x) + if keepdims: + return _reduce_sum_keepdims(x, axis) + return _reduce_sum_default(x, axis) + + +def nansum(a, axis=None, dtype=None, keepdims=False): + """ + Returns the sum of array elements over a given axis treating Not a Numbers (NaNs) as zero. + + Note: + Numpy arguments `out` is not supported. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): Array containing numbers + whose sum is desired. If `a` is not an array, a conversion is attempted. + axis (Union[int, tuple of int, None], optional): Axis or axes along which the sum is + computed. The default is to compute the sum of the flattened array. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + keepdims (boolean, optional): defaults to False. If this is set to True, the axes which + are reduced are left in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + Returns: + Tensor. + + Raises: + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. + + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([[1, 1], [1, np.nan]]) + >>> output = np.nansum(a) + >>> print(output) + 3.0 + >>> output = np.nansum(a, axis=0) + >>> print(output) + [2. 1.] + """ + a = _to_tensor(a) + nan_mask = _isnan(a) + a = F.select(nan_mask, zeros(F.shape(a), F.dtype(a)), a) + if dtype is None and _get_device() == 'CPU' and not _check_is_float(F.dtype(a)): + # F.reduce_sum only supports float on CPU + dtype = F.dtype(a) + a = F.cast(a, mstype.float32) + return _reduce(a, functools.partial(_reduce_nansum, keepdims=keepdims), axis=axis, + keepdims=keepdims, dtype=dtype) + + +def _count_nonnan(a, axis, keepdims=False): + """Counts the number of elements excluding NaNs.""" + nonnan_mask = F.select(_isnan(a), zeros(F.shape(a), F.dtype(a)), ones(F.shape(a), F.dtype(a))) + if keepdims: + return _reduce_sum_keepdims(nonnan_mask, axis) + return _reduce_sum_default(nonnan_mask, axis) + + +def nanmean(a, axis=None, dtype=None, keepdims=False): + """ + Computes the arithmetic mean along the specified axis, ignoring NaNs. + + Returns the average of the array elements. The average is taken over the flattened + array by default, otherwise over the specified axis. float32 intermediate and + return values are used for integer inputs. 
+ + Note: + Numpy arguments `out` is not supported. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): Array containing numbers + whose mean is desired. If `a` is not an array, a conversion is attempted. + axis (Union[int, tuple of int, None], optional): Axis or axes along which the mean is + computed. The default is to compute the mean of the flattened array. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + keepdims (boolean, optional): defaults to False. If this is set to True, the axes which + are reduced are left in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + Returns: + Tensor. + + Raises: + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. + + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([[1, np.nan], [3, 4]]) + >>> output = np.nanmean(a) + >>> print(output) + 2.6666667 + >>> output = np.nanmean(a, axis=0) + >>> print(output) + [2. 4.] + >>> output = np.nanmean(a, axis=1) + >>> print(output) + [1. 3.5] + """ + a = _to_tensor(a) + axis = _check_axis_valid(axis, F.rank(a)) + sum_a = nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) + return F.tensor_div(sum_a, _count_nonnan(a, axis, keepdims)) + + +def _nanvar(a, axis, ddof=0, keepdims=False): + """Computes nanvar without applying keyword arguments.""" + mean_a = nanmean(a, axis=axis, keepdims=True) + pow_a = F.tensor_pow(F.tensor_sub(a, mean_a), 2) + sum_a = _reduce_nansum(pow_a, axis, keepdims) + count = _count_nonnan(a, axis, keepdims) + return F.tensor_div(sum_a, F.tensor_sub(count, ddof)) + + +def nanvar(a, axis=None, dtype=None, ddof=0, keepdims=False): + """ + Computes the variance along the specified axis, while ignoring NaNs. + + Returns the variance of the array elements, a measure of the spread of a distribution. The + variance is computed for the flattened array by default, otherwise over the specified axis. + + Note: + Numpy arguments `out` is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): Array containing numbers + whose variance is desired. If `a` is not an array, a conversion is attempted. + axis (Union[int, tuple of int, None], optional): Axis or axes along which the variance is + computed. The default is to compute the variance of the flattened array. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + ddof (int, optional): “Delta Degrees of Freedom”: the divisor used in the calculation is + ``N - ddof``, where `N` represents the number of non-NaN elements. By default `ddof` + is zero. + keepdims (boolean, optional): defaults to False. If this is set to True, the axes which + are reduced are left in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + Returns: + Tensor. + + Raises: + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. + + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([[1, np.nan], [3, 4]]) + >>> output = np.nanstd(a) + >>> print(output) + 1.2472192 + >>> output = np.nanstd(a, axis=0) + >>> print(output) + [1. 0.] + >>> output = np.nanstd(a, axis=1) + >>> print(output) + [0. 
0.5] + """ + return _reduce(a, functools.partial(_nanvar, ddof=ddof, keepdims=keepdims), axis=axis, + keepdims=keepdims, dtype=dtype) + + +def nanstd(a, axis=None, dtype=None, ddof=0, keepdims=False): + """ + Computes the standard deviation along the specified axis, while ignoring NaNs. + + Returns the standard deviation, a measure of the spread of a distribution, of the non-NaN + array elements. The standard deviation is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `out` is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): Calculates the standard deviation of the non-NaN values. + axis (Union[int, tuple of int, None], optional): Axis or axes along which the standard + deviation is computed. The default is to compute the standard deviation of the + flattened array. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + ddof (int, optional): “Delta Degrees of Freedom”: the divisor used in the calculation is + ``N - ddof``, where `N` represents the number of non-NaN elements. By default `ddof` + is zero. + keepdims (boolean, optional): defaults to False. If this is set to True, the axes which + are reduced are left in the result as dimensions with size one. With this option, + the result will broadcast correctly against the original `a`. + + Returns: + Tensor. + + Raises: + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. + + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([[1, np.nan], [3, 4]]) + >>> output = np.nanvar(a) + >>> print(output) + 1.5555557 + >>> output = np.nanvar(a, axis=0) + >>> print(output) + [1. 0.] + >>> output = np.nanvar(a, axis=1) + >>> print(output) + [0. 0.25] + """ + return _reduce(a, lambda a, axis: F.sqrt(_nanvar(a, axis, ddof=ddof, keepdims=keepdims)), + axis=axis, keepdims=keepdims, dtype=dtype) + + +def exp2(x, dtype=None): + """ + Calculates ``2**p`` for all p in the input array. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + x (Tensor): input values. + dtype (:class:`mindspore.dtype`, optional): defaults to :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise 2 to the power `x`. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.array([2, 3]).astype(np.float32) + >>> output = np.exp2(x) + >>> print(output) + [4. 8.] + """ + return _apply_tensor_op(lambda x: F.tensor_pow(2, x), x, dtype=dtype) + + +def kron(a, b): + """ + Kronecker product of two arrays. + + Computes the Kronecker product, a composite array made of blocks of the second + array scaled by the first. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): input values. + b (Union[int, float, bool, list, tuple, Tensor]): input values. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.kron([1,10,100], [5,6,7]) + >>> print(output) + [ 5 6 7 50 60 70 500 600 700] + >>> output = np.kron([5,6,7], [1,10,100]) + >>> print(output) + [ 5 50 500 6 60 600 7 70 700] + >>> output = np.kron(np.eye(2), np.ones((2,2))) + >>> print(output) + [[1. 1. 0. 0.] + [1. 1. 0. 0.] + [0. 0. 1. 
1.] + [0. 0. 1. 1.]] + """ + a, b = _to_tensor(a, b) + ndim = _max(F.rank(a), F.rank(b)) + if ndim == 0: + return F.tensor_mul(a, b) + a = _expand(a, ndim) + b = _expand(b, ndim) + shape_a = F.shape(a) + shape_b = F.shape(b) - if isinstance(out, Tensor) and isinstance(where, Tensor): - out = where_(where, res, out) - elif out is None or where is not None: - out = res + # scales a by the shape of b + kron_shape = _seq_prod(shape_a, shape_b) + a = F.reshape(a, _add_unit_axes(shape_a, 2*ndim, True)) + a = F.tile(a, _add_unit_axes(shape_b, 2*ndim, False)) + a = moveaxis(a, F.make_range(ndim, 2*ndim), F.make_range(1, 2*ndim, 2)) + a = F.reshape(a, kron_shape) + # scales b by the shape of a + b = F.tile(b, shape_a) + return F.tensor_mul(a, b) - if not _check_same_type(F.dtype(out), dtype_out): - out = F.cast(out, dtype_out) - return out +def cross(a, b, axisa=- 1, axisb=- 1, axisc=- 1, axis=None): + """ + Returns the cross product of two (arrays of) vectors. + + The cross product of `a` and `b` in :math:`R^3` is a vector perpendicular to both + `a` and `b`. If `a` and `b` are arrays of vectors, the vectors are defined by the + last axis of `a` and `b` by default, and these axes can have dimensions 2 or 3. + Where the dimension of either `a` or `b` is 2, the third component of the input + vector is assumed to be zero and the cross product calculated accordingly. In cases + where both input vectors have dimension 2, the z-component of the cross product is + returned. + + Args: + a (Union[int, float, bool, list, tuple, Tensor]): Components of the first vector(s). + b (Union[int, float, bool, list, tuple, Tensor]): Components of the second vector(s). + axisa (int, optional): Axis of `a` that defines the vector(s). By default, the last + axis. + axisb (int, optional): Axis of `b` that defines the vector(s). By default, the last + axis. + axisc (int, optional): Axis of `c` containing the cross product vector(s). Ignored + if both input vectors have dimension 2, as the return is scalar. By default, + the last axis. + axis (int, optional): If defined, the axis of `a`, `b` and `c` that defines the + vector(s) and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns: + Tensor, vector cross product(s). + + Raises: + ValueError: when the dimensions of the vector(s) in `a` and/or `b` equal 2 or 3. 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x = np.array([[1,2,3], [4,5,6]])
+        >>> y = np.array([[4,5,6], [1,2,3]])
+        >>> output = np.cross(x, y)
+        >>> print(output)
+        [[-3  6 -3]
+         [ 3 -6  3]]
+        >>> output = np.cross(x, y, axisc=0)
+        >>> print(output)
+        [[-3  3]
+         [ 6 -6]
+         [-3  3]]
+    """
+    a, b = _to_tensor(a, b)
+    if axis is not None:
+        axisa, axisb, axisc = axis, axis, axis
+
+    _check_axis_in_range(axisa, F.rank(a))
+    _check_axis_in_range(axisb, F.rank(b))
+    a = moveaxis(a, axisa, -1)
+    b = moveaxis(b, axisb, -1)
+    shape_a = F.shape(a)
+    shape_b = F.shape(b)
+    if F.shape(a)[-1] not in (2, 3) or F.shape(b)[-1] not in (2, 3):
+        _raise_value_error('incompatible dimensions for cross product (dimension must be 2 or 3)')
+    a_has_z = shape_a[-1] == 3
+    b_has_z = shape_b[-1] == 3
+    shape_out = _infer_out_shape(shape_a[:-1], shape_b[:-1])
+    if a_has_z or b_has_z:
+        shape_out += (3,)
+    _check_axis_in_range(axisc, len(shape_out))
+
+    dtype = _promote(F.dtype(a), F.dtype(b))
+    if _get_device() == 'CPU':
+        # F.tensor_slice only supports float on CPU
+        if not _check_is_float(F.dtype(a)):
+            a = F.cast(a, mstype.float32)
+        if not _check_is_float(F.dtype(b)):
+            b = F.cast(b, mstype.float32)
+
+    a_slice_start = _list_comprehensions(F.rank(a) - 1, 0, True)
+    a_slice_size = shape_a[:-1] + (1,)
+    b_slice_start = _list_comprehensions(F.rank(b) - 1, 0, True)
+    b_slice_size = shape_b[:-1] + (1,)
+
+    def _get_slice_product(idx_a, idx_b):
+        return multiply(F.tensor_slice(a, a_slice_start + (idx_a,), a_slice_size),
+                        F.tensor_slice(b, b_slice_start + (idx_b,), b_slice_size))
+
+    cz = F.tensor_sub(_get_slice_product(0, 1), _get_slice_product(1, 0))  # ax*by - ay*bx
+    if not a_has_z and not b_has_z:
+        return F.reshape(cz, shape_out).astype(dtype)
+
+    if a_has_z and b_has_z:
+        cx = F.tensor_sub(_get_slice_product(1, 2), _get_slice_product(2, 1))  # ay*bz - az*by
+        cy = F.tensor_sub(_get_slice_product(2, 0), _get_slice_product(0, 2))  # az*bx - ax*bz
+    elif a_has_z:
+        cx = F.neg_tensor(_get_slice_product(2, 1))  # -az*by
+        cy = _get_slice_product(2, 0)  # az*bx
+    else:  # b_has_z
+        cx = _get_slice_product(1, 2)  # ay*bz
+        cy = F.neg_tensor(_get_slice_product(0, 2))  # -ax*bz
+    res = _concat((cx, cy, cz)).reshape(shape_out)
+    return moveaxis(res, -1, axisc).astype(dtype)
+
+
+def ceil(x, dtype=None):
+    """
+    Returns the ceiling of the input, element-wise.
+
+    The ceil of the scalar `x` is the smallest integer `i`, such that ``i >= x``.
+
+    Note:
+        Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are
+        not supported.
+        On GPU, the supported dtypes are np.float16, and np.float32.
+
+    Args:
+        x (Tensor): input values.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, the ceiling of each element in `x`. This is a scalar if `x` is a scalar.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
+        >>> output = np.ceil(a)
+        >>> print(output)
+        [-1. -1. -0.  1.  2.  2.  2.]
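+        >>> # Additional usage sketch (illustrative): values that are already integral are unchanged.
+        >>> print(np.ceil(np.array([1.0, 2.0, 3.0])))
+        [1. 2. 3.]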
+ """ + return _apply_tensor_op(lambda x: F.neg_tensor(F.floor(F.neg_tensor(x.astype(mstype.float32)))), + x, dtype=dtype) + + +def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b): + """Infers the shape of the last two dimensions after performing matmul.""" + shape_rem = () + if ndim1 >= 2: + shape_rem += (shape1[-2],) + if transpose_b: + if ndim2 >= 2: + shape_rem += (shape2[-2],) + else: + if ndim1 >= 1: + shape_rem += (shape2[-1],) + return shape_rem + + +def positive(a, dtype=None): + """ + Numerical positive, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + a (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, -1]).astype('float32') + >>> output = np.positive(a) + >>> print(output) + [1. -1.] + """ + _check_input_tensor(a) + neg_tensor = F.neg_tensor(a) + return _apply_tensor_op(F.neg_tensor, neg_tensor, dtype=dtype) + + +def negative(a, dtype=None): + """ + Numerical negative, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + a (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, -1]).astype('float32') + >>> output = np.negative(a) + >>> print(output) + [-1. 1.] + """ + return _apply_tensor_op(F.neg_tensor, a, dtype=dtype) + + +def cumsum(a, axis=None, dtype=None): + """ + Returns the cumulative sum of the elements along a given axis. + + Note: + If ``a.dtype`` is :class:`int8`, :class:`int16` or :class:`bool`, the result + `dtype` will be elevated to :class:`int32`. + + Args: + a (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, + unless `a` has an integer dtype with a precision less than that of the + default platform integer. In that case, the default platform integer + is used. + + Returns: + Tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis is out of range. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.cumsum(np.ones((3,3)), axis=0) + >>> print(output) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 
3.]] + """ + _check_input_tensor(a) + original_dtype = F.dtype(a) + # If original tensor is int, and has precision less then int32, convert to int32 + if _check_same_type(original_dtype, mstype.bool_) or \ + _check_same_type(original_dtype, mstype.int8) or \ + _check_same_type(original_dtype, mstype.int16): + original_dtype = mstype.int32 + a = a.astype(mstype.float32) + if axis is None: + a = a.ravel() + axis = 0 + _check_axis_in_range(axis, a.ndim) + if dtype is not None and not _check_same_type(original_dtype, dtype): + return _cumsum_default(a, axis).astype(dtype, copy=False) + return _cumsum_default(a, axis).astype(original_dtype, copy=False) + + +def nancumsum(a, axis=None, dtype=None): + """ + Return the cumulative sum of array elements over a given axis treating Not a Numbers (NaNs) + as zero. The cumulative sum does not change when NaNs are encountered and leading NaNs are + replaced by zeros. + + Zeros are returned for slices that are all-NaN or empty. + + Note: + If ``a.dtype`` is :class:`int8`, :class:`int16` or :class:`bool`, the result + `dtype` will be elevated to :class:`int32`. + + Args: + a (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, + unless `a` has an integer dtype with a precision less than that of the + default platform integer. In that case, the default platform integer + is used. + + Returns: + Tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis is out of range. + + Supported Platforms: + ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([[1, 2], [3, np.nan]]) + >>> output = np.nancumsum(a) + >>> print(output) + [1. 3. 6. 6.] + >>> output = np.nancumsum(a, axis=0) + >>> print(output) + [[1. 2.] + [4. 2.]] + >>> output = np.nancumsum(a, axis=1) + >>> print(output) + [[1. 3.] + [3. 3.]] + """ + a = F.select(_isnan(a), zeros(F.shape(a), F.dtype(a)), a) + return cumsum(a, axis=axis, dtype=dtype) + + +def cbrt(x, dtype=None): + """ + Returns the cube-root of a tensor, element-wise. + + Note: + Numpy arguments `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, -1, 3, -8, 64]) + >>> output = np.cbrt(a) + >>> print(output) + [ 1. -1. 1.4422495 -2. 4. ] + """ + def _cbrt(x): + compute_type = promote_types(x.dtype, "float32") + x = x.astype(compute_type) + # TODO: use P.Sign() once gpu support is added + abs_x = F.absolute(x) + sign_x = abs_x / x + return sign_x * F.tensor_pow(abs_x, 1. / 3.) + return _apply_tensor_op(_cbrt, x, dtype=dtype) + + +def log1p(x, dtype=None): + """ + Returns the natural logarithm of one plus the input array, element-wise. + + Calculates ``log(1 + x)``. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input array. + dtype (:class:`mindspore.dtype`): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([1, 2, 3]).astype('float16') + >>> output = np.log1p(x) + >>> print(output) + [0.6934 1.099 1.387 ] + """ + return _apply_tensor_op(lambda x: F.log(x + 1), x, dtype=dtype) + + +def logaddexp(x1, x2, dtype=None): + """ + Logarithm of the sum of exponentiations of the inputs. + + Calculates ``log(exp(x1) + exp(x2))``. This function is useful in statistics where the + calculated probabilities of events may be so small as to exceed the range of normal + floating point numbers. In such cases the logarithm of the calculated probability is + stored. This function allows adding probabilities stored in such a fashion. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x1 (Tensor): Input array. + x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + dtype (:class:`mindspore.dtype`): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([1, 2, 3]).astype('float16') + >>> x2 = np.array(2).astype('float16') + >>> output = np.logaddexp(x1, x2) + >>> print(output) + [2.312 2.693 3.312] + """ + def _logaddexp(x1, x2): + return F.log(F.tensor_add(F.tensor_exp(x1), F.tensor_exp(x2))) + return _apply_tensor_op(_logaddexp, x1, x2, dtype=dtype) + + +def log2(x, dtype=None): + """ + Base-2 logarithm of `x`. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([2, 4, 8]).astype('float16') + >>> output = np.log2(x) + >>> print(output) + [1. 2. 3.] + """ + tensor_2 = _make_tensor(2, x.dtype) + def _log2(x): + return F.log(x) / F.log(tensor_2) + return _apply_tensor_op(_log2, x, dtype=dtype) + + +def logaddexp2(x1, x2, dtype=None): + """ + Logarithm of the sum of exponentiations of the inputs in base of 2. + + Calculates ``log2(2**x1 + 2**x2)``. + This function is useful in machine learning when the calculated probabilities of events + may be so small as to exceed the range of normal floating point numbers. + In such cases the base-2 logarithm of the calculated probability can be used instead. + This function allows adding probabilities stored in such a fashion. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x1 (Tensor): Input tensor. + x2 (Tensor): Input tensor. If ``x1.shape != x2.shape``, they must be broadcastable to + a common shape (which becomes the shape of the output). + dtype (:class:`mindspore.dtype`): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if both `x1` and `x2` are scalars. 
+ + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([2, 4, 8]).astype('float16') + >>> x2 = np.array(2).astype('float16') + >>> output = np.logaddexp2(x1, x2) + >>> print(output) + [3. 4.32 8.02] + """ + _check_input_tensor(x1, x2) + add_exp = F.tensor_add(F.tensor_pow(2, x1), F.tensor_pow(2, x2)) + return log2(add_exp, dtype=dtype) + + +def log10(x, dtype=None): + """ + Base-10 logarithm of `x`. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([10, 100, 1000]).astype('float16') + >>> output = np.log10(x) + >>> print(output) + [1. 2. 3.] + """ + tensor_10 = _make_tensor(10, x.dtype) + def _log10(x): + return F.log(x) / F.log(tensor_10) + return _apply_tensor_op(_log10, x, dtype=dtype) + + +def _cast_type_for_trigonometric(x): + _check_input_tensor(x) + if x.dtype != mstype.float16 or x.dtype != mstype.float32 or x.dtype != mstype.float64: + dtype = _promote_for_trigonometric(x.dtype) + x = F.cast(x, dtype) + return x + + +def sin(x, dtype=None): + """ + Trigonometric sine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([-5, -1, 0, 2, 4, 100]).astype('float32') + >>> output = np.sin(x) + >>> print(output) + [ 0.9589243 -0.84147096 0. 0.9092974 -0.7568025 -0.50636566] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.sin, x, dtype=dtype) + + +def cos(x, dtype=None): + """ + Cosine element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.cos(x)) + [ 1. 0.5403023 -0.41614684 -0.9899925 -0.6536436 ] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.cos, x, dtype=dtype) + + +def tan(x, dtype=None): + """ + Computes tangent element-wise. + + Equivalent to :math:`np.sin(x)/np.cos(x)` element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. 
+ + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: If the input is not a tensor or is :class:`tensor.dtype` is :class:`mindsproe.float64`. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([-5, -1, 0, 2, 4, 100]).astype('float32') + >>> print(np.tan(x)) + [ 3.380515 -1.5574077 0. -2.1850398 1.1578213 -0.58721393] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.tan, x, dtype=dtype) + + +def arcsin(x, dtype=None): + """ + Inverse sine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. y-coordinate on the unit circle. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Raises: + TypeError: If the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.asarray([1, -1], np.float32) + >>> output = np.arcsin(x) + >>> print(output) + [ 1.5707964 -1.5707964] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.asin, x, dtype=dtype) + + +def arccos(x, dtype=None): + """ + Trigonometric inverse cosine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. x-coordinate on the unit circle. + For real arguments, the domain is :math:`[-1, 1]`. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Raises: + TypeError: If the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.asarray([1, -1], np.float32) + >>> output = np.arccos(x) + >>> print(output) + [0. 3.1415927] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.acos, x, dtype=dtype) + + +def arctan(x, dtype=None): + """ + Trigonometric inverse tangent, element-wise. + + The inverse of tan, so that if :math:`y = tan(x)` then :math:`x = arctan(y)`. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.tan(x)) + [ 0. 1.5574077 -2.1850398 -0.14254655 1.1578213 ] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.atan, x, dtype=dtype) + + +def sinh(x, dtype=None): + """ + Hyperbolic sine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. 
+ + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.sinh(x)) + [ 0. 1.1752012 3.6268604 10.017875 27.289917 ] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.sinh, x, dtype=dtype) + + +def cosh(x, dtype=None): + """ + Hyperbolic cosine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.cosh(x)) + [ 1. 1.5430807 3.7621956 10.067662 27.308233 ] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.cosh, x, dtype=dtype) + + +def tanh(x, dtype=None): + """ + Computes hyperbolic tangent element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.tanh(x)) + [0. 0.7615942 0.9640276 0.9950548 0.9993293] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.tanh, x, dtype=dtype) + + +def arcsinh(x, dtype=None): + """ + Inverse hyperbolic sine element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(5).astype('float32') + >>> print(np.arcsinh(x)) + [0. 0.8813736 1.4436355 1.8184465 2.0947125] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.asinh, x, dtype=dtype) + + +def arccosh(x, dtype=None): + """ + Inverse hyperbolic cosine, element-wise. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.arange(1, 5).astype('float32') + >>> print(np.arccosh(x)) + [0. 1.316958 1.7627472 2.063437 ] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.acosh, x, dtype=dtype) + + +def arctanh(x, dtype=None): + """ + Inverse hyperbolic tangent element-wise. 
+ + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x (Tensor): Input tensor. + dtype (:class:`mindspore.dtype`, optional): Default: :class:`None`. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.array([-0.99, -0.75, -0.5, 0, 0.5]).astype('float32') + >>> print(np.arctanh(x)) + [-2.646653 -0.97295505 -0.54930615 0. 0.54930615] + """ + x = _cast_type_for_trigonometric(x) + return _apply_tensor_op(F.atanh, x, dtype=dtype) + + +def arctan2(x1, x2, dtype=None): + """ + Element-wise arc tangent of :math:`x1/x2` choosing the quadrant correctly. + + Note: + Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + x1 (Tensor): input tensor. + x2 (Tensor): input tensor. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar + if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x1 = np.array([-1, +1, +1, -1]) + >>> x2 = np.array([-1, -1, +1, +1]) + >>> output = np.arctan2(x1, x2) + >>> print(output) + [-2.3561945 2.3561945 0.78539819 -0.78539819] + """ + x1 = _cast_type_for_trigonometric(x1) + x2 = _cast_type_for_trigonometric(x2) + return _apply_tensor_op(F.atan2, x1, x2, dtype=dtype) + + +def promote_types(type1, type2): + """ + Returns the data type with the smallest size and smallest scalar kind. + + Note: + The promotion rule is slightly different from original Numpy, but more like + jax, due to the preference on ``32-bit`` over ``64-bit`` data types. + + Args: + type1 (Union[:class:`mindspore.dtype`, str]): First data type. + type2 (Union[:class:`mindspore.dtype`, str]): Second data type. + + Returns: + The promoted data type. + + Raises: + TypeError: if the input are not valid :class:`mindspore.dtype` input. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> output = np.promote_types(np.float32, np.float64) + >>> print(output) + np.float64 + """ + type1 = _check_dtype(type1) + type2 = _check_dtype(type2) + return _promote(type1, type2) + + +def _apply_tensor_op(fn, *args, dtype=None): + """Applies tensor operations based on fn""" + args = _to_tensor(*args) + if isinstance(args, Tensor): + res = fn(args) + else: + res = fn(*args) + if dtype is not None and not _check_same_type(F.dtype(res), dtype): + res = F.cast(res, dtype) + return res diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py index eed0c7f492..1b40324b7d 100644 --- a/mindspore/numpy/utils.py +++ b/mindspore/numpy/utils.py @@ -13,14 +13,11 @@ # limitations under the License. 
# ============================================================================ """internal utility functions""" - -import numpy as onp - from ..common import Tensor from ..ops import functional as F from ..common import dtype as mstype -from .utils_const import _tile_size, _add_unit_axes, _raise_type_error +from .utils_const import _tile_size, _add_unit_axes, _raise_type_error, _type_convert def _deep_list(array_like): @@ -56,9 +53,8 @@ def _deep_tensor_to_nparray(array_like): def _check_input_for_asarray(array_like): """check whether array_like argument is a valid type for np.asarray conversion""" - if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): - _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ - "or numpy.ndarray, but got ", array_like) + if not isinstance(array_like, (Tensor, list, tuple, int, float, bool)): + _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`, but got ", array_like) def _is_scalar(shape): @@ -121,6 +117,20 @@ def _convert_64_to_32(tensor): return tensor +def _to_tensor(*args): + """Returns each input as Tensor""" + res = () + for arg in args: + if isinstance(arg, (int, float, bool, list, tuple)): + arg = _convert_64_to_32(_type_convert(Tensor, arg)) + elif not isinstance(arg, Tensor): + _raise_type_error("Expect input to be array like.") + res += (arg,) + if len(res) == 1: + return res[0] + return res + + def _get_dtype_from_scalar(*input_numbers): """ Get the final dtype from series of input numbers, compared with F.typeof, we @@ -139,3 +149,8 @@ def _get_dtype_from_scalar(*input_numbers): if int_flag: return mstype.int32 return mstype.float32 + + +def _isnan(x): + """Computes isnan.""" + return F.not_equal(x, x) diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py index 872aa6399b..aa1bbaafd7 100644 --- a/mindspore/numpy/utils_const.py +++ b/mindspore/numpy/utils_const.py @@ -14,7 +14,8 @@ # ============================================================================ """internal graph-compatible utility functions""" import math -from functools import partial +from itertools import zip_longest +from collections import deque import mindspore.context as context from ..ops import functional as F @@ -24,7 +25,7 @@ from ..common import Tensor from .._c_expression import Tensor as Tensor_ from .._c_expression import typing -from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map +from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric @constexpr @@ -110,44 +111,19 @@ def _get_device(): return context.get_context('device_target') -@constexpr -def _reverse_index(idx, arr): - """ - Returns 1 if shape[idx:] is broadcastable to shape_out[idx:], - 2 situations if the function returns 1: - - 1. Tensor's shape has 1 at the designated dimension. - - 2. Tensor's dimension is less than the designated idx. (The Tensor shape - has been reversed) - For both cases, 2 tensors are broadcastable. - otherwise returns the element at position of shape - """ - if len(arr) <= idx: - return 1 - return arr[-1 - idx] - - @constexpr def _infer_out_shape(*shapes): """ - Returns shape of output after broadcasting - Raises ValueError if shape1 and shape2 cannot be broadcast + Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. 
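For reference, the broadcasting rule that the rewritten `_infer_out_shape` (continued just below) implements can be sketched as a standalone helper; this is a pure-Python illustration with no MindSpore dependency, and `broadcast_shape` is an illustrative name rather than part of the package:
from itertools import zip_longest
from collections import deque

def broadcast_shape(*shapes):
    # Right-align the shapes; each aligned dimension must equal the running max or be 1,
    # and a 0-sized dimension propagates as an empty output dimension.
    out = deque()
    for dims in zip_longest(*map(reversed, shapes), fillvalue=1):
        size = 0 if 0 in dims else max(dims)
        if any(d not in (1, size) for d in dims):
            raise ValueError(f'operands could not be broadcast together with shapes {shapes}')
        out.appendleft(size)
    return tuple(out)

print(broadcast_shape((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)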
""" - shapes_unbroadcastable = False - ndim_max = max(map(len, shapes)) - shape_out = [0]*ndim_max - i = 0 - for i in range(ndim_max): - shape_out[-1 - i] = max(map(partial(_reverse_index, i), shapes)) - for shape in shapes: - if _reverse_index(i, shape) != shape_out[-1 - i]: - if _reverse_index(i, shape) != 1: - shapes_unbroadcastable = True - break - if shapes_unbroadcastable: - break - if not shapes_unbroadcastable: - return tuple(shape_out) - raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') + shape_out = deque() + reversed_shapes = map(reversed, shapes) + for items in zip_longest(*reversed_shapes, fillvalue=1): + max_size = 0 if 0 in items else max(items) + if any(item not in (1, max_size) for item in items): + raise ValueError(f'operands could not be broadcast together with shapes {*shapes,}') + shape_out.appendleft(max_size) + return tuple(shape_out) @constexpr @@ -228,6 +204,21 @@ def _raise_value_error(info, param=None): raise ValueError(info + f"{param}") +@constexpr +def _raise_runtime_error(info, param=None): + """ + Raise RuntimeError in both graph/pynative mode + + Args: + info(str): info string to display + param(python obj): any object that can be recognized by graph mode. If is + not None, then param's value information will be extracted and displayed. + Default is None. + """ + if param is None: + raise RuntimeError(info) + raise RuntimeError(info + f"{param}") + @constexpr def _empty(dtype, shape): """Returns an uninitialized array with dtype and shape.""" @@ -242,6 +233,9 @@ def _promote(dtype1, dtype2): return promotion_rule[dtype1, dtype2] return promotion_rule[dtype2, dtype1] +@constexpr +def _promote_for_trigonometric(dtype): + return rule_for_trigonometric[dtype] @constexpr def _max(*args): @@ -315,7 +309,7 @@ def _canonicalize_axis(axis, ndim): axis = tuple([canonicalizer(axis) for axis in axis]) if all(axis.count(el) <= 1 for el in axis): - return axis if len(axis) > 1 else axis[0] + return tuple(sorted(axis)) if len(axis) > 1 else axis[0] raise ValueError(f"duplicate axes in {axis}.") @@ -426,13 +420,37 @@ def _tuple_getitem(tup, idx, startswith=True): @constexpr -def _iota(dtype, num): +def _tuple_setitem(tup, idx, value): + """ + Returns a tuple with specified `idx` set to `value`. + """ + tup = list(tup) + tup[idx] = value + return tuple(tup) + + +@constexpr +def _iota(dtype, num, increasing=True): """Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" # TODO: Change to P.Linspace when the kernel is implemented on CPU. - return Tensor(list(range(int(num))), dtype) + if increasing: + return Tensor(list(range(int(num))), dtype) + return Tensor(list(range(int(num)-1, -1, -1)), dtype) @constexpr def _ceil(number): """Ceils the number in graph mode.""" return math.ceil(number) + + +@constexpr +def _seq_prod(seq1, seq2): + """Returns the element-wise product of seq1 and seq2.""" + return tuple(map(lambda x, y: x*y, seq1, seq2)) + + +@constexpr +def _make_tensor(val, dtype): + """ Returns the tensor with value `val` and dtype `dtype`.""" + return Tensor(val, dtype) diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index 0803649226..54e74b6b9f 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -15,6 +15,7 @@ """Implementation for internal polymorphism `not equal` operations.""" +from . import _constexpr_utils as const_utils from ...composite import base from ... 
import functional as F @@ -41,6 +42,21 @@ def _not_equal_scalar(x, y): return not F.scalar_eq(x, y) +@not_equal.register("mstype", "mstype") +def _not_equal_mstype(x, y): + """ + Determine if two mindspore types are not equal. + + Args: + x (mstype): first input mindspore type. + y (mstype): second input mindspore type. + + Returns: + bool, if x != y return true, x == y return false. + """ + return not const_utils.mstype_eq(x, y) + + @not_equal.register("String", "String") def _not_equal_string(x, y): """ diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 5051936d5f..2dcd13a5bf 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -77,6 +77,7 @@ floormod = tensor_mod tensor_exp = P.Exp() exp = tensor_exp tensor_expm1 = P.Expm1() +tensor_slice = P.Slice() strided_slice = P.StridedSlice() same_type_shape = P.SameTypeShape() check_bprop = P.CheckBprop() @@ -94,6 +95,22 @@ tensor_slice = P.Slice() maximum = P.Maximum() minimum = P.Minimum() floor = P.Floor() +logical_not = P.LogicalNot() +logical_or = P.LogicalOr() +logical_and = P.LogicalAnd() +sin = P.Sin() +cos = P.Cos() +tan = P.Tan() +asin = P.Asin() +acos = P.ACos() +atan = P.Atan() +sinh = P.Sinh() +cosh = P.Cosh() +tanh = P.Tanh() +asinh = P.Asinh() +acosh = P.Acosh() +atanh = P.Atanh() +atan2 = P.Atan2() scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 35b093ed39..f7e72b1353 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -2560,7 +2560,7 @@ class Acosh(PrimitiveWithInfer): TypeError: If `input_x` is not a Tensor. Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` Examples: >>> acosh = ops.Acosh() @@ -2637,7 +2637,7 @@ class Asinh(PrimitiveWithInfer): TypeError: If `input_x` is not a Tensor. 
Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` Examples: >>> asinh = ops.Asinh() diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index 75c4754acf..16ba9be5b6 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -20,7 +20,7 @@ import numpy as onp import mindspore.numpy as mnp from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ - match_all_arrays + match_all_arrays, run_multi_test, to_tensor class Cases(): @@ -40,8 +40,8 @@ class Cases(): self.array_sets = [1, 1.1, True, [1, 0, True], [1, 1.0, 2], (1,), [(1, 2, 3), (4, 5, 6)], onp.random.random( # pylint: disable=no-member - (100, 100)).astype(onp.float32), - onp.random.random((100, 100)).astype(onp.bool)] + (100, 100)).astype(onp.float32).tolist(), + onp.random.random((100, 100)).astype(onp.bool).tolist()] self.arrs = [ rand_int(2), @@ -138,8 +138,8 @@ def test_asarray(): expected = mnp.asarray(array, test_case.mnp_dtypes[i]).asnumpy() match_array(actual, expected, error=7) - # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + # Additional tests for nested tensor mixture + mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] actual = onp.asarray(onp_input) @@ -168,11 +168,11 @@ def test_array(): assert arr4 is arr5 # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - actual = onp.asarray(onp_input) - expected = mnp.asarray(mnp_input).asnumpy() + actual = onp.array(onp_input) + expected = mnp.array(mnp_input).asnumpy() match_array(actual, expected, error=7) @@ -202,11 +202,11 @@ def test_asfarray(): match_array(actual, expected, error=7) # Additional tests for nested tensor/numpy_array mixture - mnp_input = [(onp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] + mnp_input = [(mnp.ones(3,), mnp.ones(3)), [[1, 1, 1], (1, 1, 1)]] onp_input = [(onp.ones(3,), onp.ones(3)), [[1, 1, 1], (1, 1, 1)]] - actual = onp.asarray(onp_input) - expected = mnp.asarray(mnp_input).asnumpy() + actual = onp.asfarray(onp_input) + expected = mnp.asfarray(mnp_input).asnumpy() match_array(actual, expected, error=7) @@ -373,14 +373,14 @@ def test_linspace(): stop = onp.random.random([1, 5, 1]).astype("float32") actual = onp.linspace(start, stop, num=20, retstep=True, endpoint=False, dtype=onp.float32) - expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, retstep=True, endpoint=False) match_array(actual[0], expected[0].asnumpy(), error=6) match_array(actual[1], expected[1].asnumpy(), error=6) actual = onp.linspace(start, stop, num=20, retstep=True, endpoint=False, dtype=onp.int16) - expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, retstep=True, endpoint=False, dtype=mnp.int16) match_array(actual[0], expected[0].asnumpy(), error=6) match_array(actual[1], expected[1].asnumpy(), error=6) @@ -388,7 +388,7 @@ def test_linspace(): for axis in range(2): actual = onp.linspace(start, stop, num=20, retstep=False, endpoint=False, dtype=onp.float32, axis=axis) - expected = 
mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + expected = mnp.linspace(to_tensor(start), to_tensor(stop), num=20, retstep=False, endpoint=False, dtype=mnp.float32, axis=axis) match_array(actual, expected.asnumpy(), error=6) @@ -510,18 +510,18 @@ def test_full_like(): for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): shape = onp.zeros_like(onp_proto).shape fill_value = rand_int() - actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() + actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) for i in range(len(shape) - 1, 0, -1): fill_value = rand_int(*shape[i:]) - actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() + actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) fill_value = rand_int(1, *shape[i + 1:]) - actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() + actual = mnp.full_like(mnp_proto, to_tensor(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) @@ -549,6 +549,21 @@ def test_tri_triu_tril(): match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_nancumsum(): + x = rand_int(2, 3, 4, 5) + x[0][2][1][3] = onp.nan + x[1][0][2][4] = onp.nan + x[1][1][1][1] = onp.nan + match_res(mnp.nancumsum, onp.nancumsum, x) + match_res(mnp.nancumsum, onp.nancumsum, x, axis=-2) + match_res(mnp.nancumsum, onp.nancumsum, x, axis=0) + match_res(mnp.nancumsum, onp.nancumsum, x, axis=3) + + def mnp_diagonal(arr): return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) @@ -653,7 +668,7 @@ def test_meshgrid(): (2, 3), 9), onp.full((4, 5, 6), 7)) for i in range(len(xi)): arrs = xi[i:] - mnp_arrs = map(mnp.asarray, arrs) + mnp_arrs = map(to_tensor, arrs) for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)): match_all_arrays(mnp_res, onp_res) @@ -750,6 +765,68 @@ def test_ix_(): match_res(mnp_ix_, onp_ix_, *test_arrs) +def mnp_indices(): + a = mnp.indices((2, 3)) + b = mnp.indices((2, 3, 4), sparse=True) + return a, b + + +def onp_indices(): + a = onp.indices((2, 3)) + b = onp.indices((2, 3, 4), sparse=True) + return a, b + + +def test_indices(): + run_multi_test(mnp_indices, onp_indices, ()) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_geomspace(): + start = onp.arange(1, 7).reshape(2, 3) + end = [1000, 2000, 3000] + match_array(mnp.geomspace(1, 256, num=9).asnumpy(), + onp.geomspace(1, 256, num=9), error=1) + match_array(mnp.geomspace(1, 256, num=8, endpoint=False).asnumpy(), + onp.geomspace(1, 256, num=8, endpoint=False), error=1) + match_array(mnp.geomspace(to_tensor(start), end, num=4).asnumpy(), + onp.geomspace(start, end, num=4), error=1) + match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False).asnumpy(), + onp.geomspace(start, end, num=4, endpoint=False), error=1) + match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=-1).asnumpy(), + onp.geomspace(start, end, num=4, axis=-1), error=1) + match_array(mnp.geomspace(to_tensor(start), end, num=4, endpoint=False, axis=-1).asnumpy(), + onp.geomspace(start, end, num=4, endpoint=False, 
axis=-1), error=1) + + start = onp.arange(1, 1 + 2*3*4*5).reshape(2, 3, 4, 5) + end = [1000, 2000, 3000, 4000, 5000] + for i in range(-5, 5): + match_array(mnp.geomspace(to_tensor(start), end, num=4, axis=i).asnumpy(), + onp.geomspace(start, end, num=4, axis=i), error=1) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_vander(): + arrs = [rand_int(i + 3) for i in range(3)] + for i in range(3): + mnp_vander = mnp.vander(to_tensor(arrs[i])) + onp_vander = onp.vander(arrs[i]) + match_all_arrays(mnp_vander, onp_vander) + mnp_vander = mnp.vander(to_tensor(arrs[i]), N=2, increasing=True) + onp_vander = onp.vander(arrs[i], N=2, increasing=True) + match_all_arrays(mnp_vander, onp_vander) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training diff --git a/tests/st/numpy_native/test_array_ops.py b/tests/st/numpy_native/test_array_ops.py index 542b4b97e1..d7c5bc5b53 100644 --- a/tests/st/numpy_native/test_array_ops.py +++ b/tests/st/numpy_native/test_array_ops.py @@ -23,7 +23,7 @@ import mindspore.numpy as mnp from mindspore.nn import Cell from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \ - rand_bool, match_res, run_multi_test + rand_bool, match_res, run_multi_test, to_tensor class Cases(): @@ -139,7 +139,7 @@ def onp_transpose(input_array): @pytest.mark.env_onecard def test_transpose(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_transposed = onp_transpose(onp_array) m_transposed = mnp_transpose(mnp_array) check_all_results(o_transposed, m_transposed) @@ -170,7 +170,7 @@ def onp_expand_dims(input_array): @pytest.mark.env_onecard def test_expand_dims(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_expanded = onp_expand_dims(onp_array) m_expanded = mnp_expand_dims(mnp_array) check_all_results(o_expanded, m_expanded) @@ -205,13 +205,13 @@ def onp_squeeze(input_array): @pytest.mark.env_onecard def test_squeeze(): onp_array = onp.random.random((1, 3, 1, 4, 2)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_squeezed = onp_squeeze(onp_array) m_squeezed = mnp_squeeze(mnp_array) check_all_results(o_squeezed, m_squeezed) onp_array = onp.random.random((1, 1, 1, 1, 1)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_squeezed = onp_squeeze(onp_array) m_squeezed = mnp_squeeze(mnp_array) check_all_results(o_squeezed, m_squeezed) @@ -246,7 +246,7 @@ def onp_rollaxis(input_array): @pytest.mark.env_onecard def test_rollaxis(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_rolled = onp_rollaxis(onp_array) m_rolled = mnp_rollaxis(mnp_array) check_all_results(o_rolled, m_rolled) @@ -281,7 +281,7 @@ def onp_swapaxes(input_array): @pytest.mark.env_onecard def test_swapaxes(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_swaped = onp_swapaxes(onp_array) m_swaped = mnp_swapaxes(mnp_array) check_all_results(o_swaped, m_swaped) @@ -324,7 +324,7 @@ def onp_reshape(input_array): @pytest.mark.env_onecard def test_reshape(): onp_array = 
onp.random.random((2, 3, 4)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_reshaped = onp_reshape(onp_array) m_reshaped = mnp_reshape(mnp_array) check_all_results(o_reshaped, m_reshaped) @@ -349,7 +349,7 @@ def onp_ravel(input_array): @pytest.mark.env_onecard def test_ravel(): onp_array = onp.random.random((2, 3, 4)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_ravel = onp_ravel(onp_array) m_ravel = mnp_ravel(mnp_array).asnumpy() match_array(o_ravel, m_ravel) @@ -380,7 +380,7 @@ def onp_concatenate(input_array): @pytest.mark.env_onecard def test_concatenate(): onp_array = onp.random.random((5, 4, 3, 2)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_concatenate = onp_concatenate(onp_array) m_concatenate = mnp_concatenate(mnp_array) check_all_results(o_concatenate, m_concatenate) @@ -407,8 +407,8 @@ def onp_append(arr1, arr2): def test_append(): onp_array = onp.random.random((4, 3, 2)).astype('float32') onp_value = onp.random.random((4, 3, 2)).astype('float32') - mnp_array = mnp.asarray(onp_array) - mnp_value = mnp.asarray(onp_value) + mnp_array = to_tensor(onp_array) + mnp_value = to_tensor(onp_value) onp_res = onp_append(onp_array, onp_value) mnp_res = mnp_append(mnp_array, mnp_value) check_all_results(onp_res, mnp_res) @@ -424,13 +424,13 @@ def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5): onp_array1 = onp.random.randint( low=low, high=high, size=shape).astype(onp.float32) onp_array_lst.append(onp_array1) - mnp_array_lst.append(mnp.asarray(onp_array1)) + mnp_array_lst.append(to_tensor(onp_array1)) if axis is not None and axis < ndim: new_shape[axis] += onp.random.randint(2) onp_array2 = onp.random.randint( low=low, high=high, size=new_shape).astype(onp.float32) onp_array_lst.append(onp_array2) - mnp_array_lst.append(mnp.asarray(onp_array2)) + mnp_array_lst.append(to_tensor(onp_array2)) return onp_array_lst, mnp_array_lst # Test np.xstack @@ -656,7 +656,7 @@ def onp_ndarray_flatten(input_array): @pytest.mark.env_onecard def test_ndarray_flatten(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_flatten = onp_ndarray_flatten(onp_array) m_flatten = mnp_ndarray_flatten(mnp_array) check_all_results(o_flatten, m_flatten) @@ -687,7 +687,7 @@ def onp_ndarray_transpose(input_array): @pytest.mark.env_onecard def test_ndarray_transpose(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_transposed = onp_ndarray_transpose(onp_array) m_transposed = mnp_ndarray_transpose(mnp_array) check_all_results(o_transposed, m_transposed) @@ -716,7 +716,7 @@ def onp_ndarray_astype(input_array): @pytest.mark.env_onecard def test_ndarray_astype(): onp_array = onp.random.random((3, 4, 5)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) o_astype = onp_ndarray_astype(onp_array) m_astype = mnp_ndarray_astype(mnp_array) for arr1, arr2 in zip(o_astype, m_astype): @@ -747,7 +747,7 @@ def mnp_concatenate_type_promotion(mnp_array1, mnp_array2, mnp_array3, mnp_array @pytest.mark.env_onecard def test_concatenate_type_promotion(): onp_array = onp.random.random((5, 1)).astype('float32') - mnp_array = mnp.asarray(onp_array) + mnp_array = to_tensor(onp_array) onp_array1 = onp_array.astype(onp.float16) onp_array2 = onp_array.astype(onp.bool_) onp_array3 = 
onp_array.astype(onp.float32) @@ -1049,7 +1049,7 @@ def test_split(): onp_arrs = [ onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32') ] - mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + mnp_arrs = [to_tensor(arr) for arr in onp_arrs] for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): o_split = onp_split(onp_arr) m_split = mnp_split(mnp_arr) @@ -1058,6 +1058,36 @@ def test_split(): match_array(expect, actual.asnumpy()) + +def mnp_array_split(input_tensor): + a = mnp.array_split(input_tensor, indices_or_sections=4, axis=2) + b = mnp.array_split(input_tensor, indices_or_sections=3, axis=1) + c = mnp.array_split(input_tensor, indices_or_sections=6) + return a, b, c + + +def onp_array_split(input_array): + a = onp.array_split(input_array, indices_or_sections=4, axis=2) + b = onp.array_split(input_array, indices_or_sections=3, axis=1) + c = onp.array_split(input_array, indices_or_sections=6) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_array_split(): + onp_arr = onp.random.randint(1, 5, size=(9, 7, 13)).astype('float32') + mnp_arr = to_tensor(onp_arr) + o_split = onp_array_split(onp_arr) + m_split = mnp_array_split(mnp_arr) + for expect_lst, actual_lst in zip(o_split, m_split): + for expect, actual in zip(expect_lst, actual_lst): + match_array(expect, actual.asnumpy()) + + def mnp_vsplit(input_tensor): a = mnp.vsplit(input_tensor, indices_or_sections=3) b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) @@ -1082,7 +1112,7 @@ def test_vsplit(): onp_arrs = [ onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32') ] - mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + mnp_arrs = [to_tensor(arr) for arr in onp_arrs] for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): o_vsplit = onp_vsplit(onp_arr) m_vsplit = mnp_vsplit(mnp_arr) @@ -1115,7 +1145,7 @@ def test_hsplit(): onp_arrs = [ onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32') ] - mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + mnp_arrs = [to_tensor(arr) for arr in onp_arrs] for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): o_hsplit = onp_hsplit(onp_arr) m_hsplit = mnp_hsplit(mnp_arr) @@ -1148,7 +1178,7 @@ def test_dsplit(): onp_arrs = [ onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32') ] - mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + mnp_arrs = [to_tensor(arr) for arr in onp_arrs] for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): o_dsplit = onp_dsplit(onp_arr) m_dsplit = mnp_dsplit(mnp_arr) @@ -1248,6 +1278,29 @@ def test_repeat(): run_multi_test(mnp_repeat, onp_repeat, (x,)) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_select(): + choicelist = rand_int(2, 3, 4, 5) + condlist = choicelist > 2 + match_res(mnp.select, onp.select, condlist, choicelist) + match_res(mnp.select, onp.select, condlist, choicelist, default=10) + + condlist = rand_bool(5, 4, 1, 3) + choicelist = rand_int(5, 3) + match_res(mnp.select, onp.select, condlist, choicelist) + match_res(mnp.select, onp.select, condlist, choicelist, default=10) + + condlist = rand_bool(3, 1, 7) + choicelist = rand_int(3, 5, 2, 1) + match_res(mnp.select, onp.select, condlist, choicelist) + match_res(mnp.select, onp.select, condlist, choicelist, default=10) + + class ReshapeExpandSqueeze(Cell): def
__init__(self): super(ReshapeExpandSqueeze, self).__init__() @@ -1333,7 +1386,7 @@ def test_swapaxes_exception(): @pytest.mark.env_onecard def test_tensor_flatten(): lst = [[1.0, 2.0], [3.0, 4.0]] - tensor_list = mnp.asarray(lst) + tensor_list = to_tensor(lst) assert tensor_list.flatten().asnumpy().tolist() == [1.0, 2.0, 3.0, 4.0] assert tensor_list.flatten(order='F').asnumpy().tolist() == [ 1.0, 3.0, 2.0, 4.0] @@ -1347,7 +1400,7 @@ def test_tensor_flatten(): @pytest.mark.env_onecard def test_tensor_reshape(): lst = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] - tensor_list = mnp.asarray(lst) + tensor_list = to_tensor(lst) with pytest.raises(TypeError): tensor_list = tensor_list.reshape({0, 1, 2}) with pytest.raises(ValueError): @@ -1364,7 +1417,7 @@ def test_tensor_reshape(): @pytest.mark.env_onecard def test_tensor_squeeze(): lst = [[[1.0], [2.0], [3.0]]] - tensor_list = mnp.asarray(lst) + tensor_list = to_tensor(lst) with pytest.raises(TypeError): tensor_list = tensor_list.squeeze(1.2) with pytest.raises(ValueError): @@ -1381,7 +1434,7 @@ def test_tensor_squeeze(): @pytest.mark.env_onecard def test_tensor_ravel(): lst = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]] - tensor_list = mnp.asarray(lst) + tensor_list = to_tensor(lst) assert tensor_list.ravel().shape == (8,) assert tensor_list.ravel().asnumpy().tolist() == [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] @@ -1395,9 +1448,47 @@ def test_tensor_ravel(): @pytest.mark.env_onecard def test_tensor_swapaxes(): lst = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - tensor_list = mnp.asarray(lst) + tensor_list = to_tensor(lst) with pytest.raises(TypeError): tensor_list = tensor_list.swapaxes(0, (1,)) with pytest.raises(ValueError): tensor_list = tensor_list.swapaxes(0, 3) assert tensor_list.swapaxes(0, 1).shape == (3, 2) + + +def mnp_rot90(input_tensor): + a = mnp.rot90(input_tensor) + b = mnp.rot90(input_tensor, 2) + c = mnp.rot90(input_tensor, 3) + d = mnp.rot90(input_tensor, 4) + e = mnp.rot90(input_tensor, 5, (0, -1)) + f = mnp.rot90(input_tensor, 1, (2, 0)) + g = mnp.rot90(input_tensor, -3, (-1, -2)) + h = mnp.rot90(input_tensor, 3, (2, 1)) + return a, b, c, d, e, f, g, h + + +def onp_rot90(input_array): + a = onp.rot90(input_array) + b = onp.rot90(input_array, 2) + c = onp.rot90(input_array, 3) + d = onp.rot90(input_array, 4) + e = onp.rot90(input_array, 5, (0, -1)) + f = onp.rot90(input_array, 1, (2, 0)) + g = onp.rot90(input_array, -3, (-1, -2)) + h = onp.rot90(input_array, 3, (2, 1)) + return a, b, c, d, e, f, g, h + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_rot90(): + onp_array = rand_int(3, 4, 5).astype('float32') + mnp_array = to_tensor(onp_array) + o_rot = onp_rot90(onp_array) + m_rot = mnp_rot90(mnp_array) + check_all_results(o_rot, m_rot) diff --git a/tests/st/numpy_native/test_logic_ops.py b/tests/st/numpy_native/test_logic_ops.py index 29e4e775c4..3d0942ade8 100644 --- a/tests/st/numpy_native/test_logic_ops.py +++ b/tests/st/numpy_native/test_logic_ops.py @@ -19,7 +19,8 @@ import numpy as onp import mindspore.numpy as mnp -from .utils import rand_int, run_binop_test, match_res +from .utils import rand_int, rand_bool, run_binop_test, run_logical_test, match_res, \ + match_all_arrays, to_tensor class Cases(): @@ -55,6 +56,15 @@ class Cases(): rand_int(8, 1, 6, 1) ] + # Boolean arrays + self.boolean_arrs = [ + rand_bool(), + rand_bool(5), + rand_bool(6, 1), + rand_bool(7, 1, 
5), + rand_bool(8, 1, 6, 1) + ] + + # array which contains infs and nans self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]]) @@ -246,10 +256,147 @@ def test_isneginf(): match_res(mnp_isneginf, onp_isneginf, test_case.infs) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard def test_isscalar(): assert mnp.isscalar(1) == onp.isscalar(1) assert mnp.isscalar(2.3) == onp.isscalar(2.3) assert mnp.isscalar([4.5]) == onp.isscalar([4.5]) assert mnp.isscalar(False) == onp.isscalar(False) - assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True)) + assert mnp.isscalar(to_tensor(True)) == onp.isscalar(onp.array(True)) assert mnp.isscalar('numpy') == onp.isscalar('numpy') + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isclose(): + a = [0, 1, 2, float('inf'), float('inf'), float('nan')] + b = [0, 1, -2, float('-inf'), float('inf'), float('nan')] + match_all_arrays(mnp.isclose(a, b), onp.isclose(a, b)) + match_all_arrays(mnp.isclose(a, b, equal_nan=True), onp.isclose(a, b, equal_nan=True)) + + a = rand_int(2, 3, 4, 5) + diff = (onp.random.random((2, 3, 4, 5)).astype("float32") - 0.5) / 1000 + b = a + diff + match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3), onp.isclose(a, b, atol=1e-3)) + match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-3, rtol=1e-4), + onp.isclose(a, b, atol=1e-3, rtol=1e-4)) + match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b), atol=1e-2, rtol=1e-6), + onp.isclose(a, b, atol=1e-2, rtol=1e-6)) + + a = rand_int(2, 3, 4, 5) + b = rand_int(4, 5) + match_all_arrays(mnp.isclose(to_tensor(a), to_tensor(b)), onp.isclose(a, b)) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_in1d(): + xi = [rand_int(), rand_int(1), rand_int(10)] + yi = [rand_int(), rand_int(1), rand_int(10)] + for x in xi: + for y in yi: + match_res(mnp.in1d, onp.in1d, x, y) + match_res(mnp.in1d, onp.in1d, x, y, invert=True) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isin(): + xi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)] + yi = [rand_int(), rand_int(1), rand_int(10), rand_int(2, 3)] + for x in xi: + for y in yi: + match_res(mnp.isin, onp.isin, x, y) + match_res(mnp.isin, onp.isin, x, y, invert=True) + + +def mnp_logical_or(x1, x2): + return mnp.logical_or(x1, x2) + + +def onp_logical_or(x1, x2): + return onp.logical_or(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logical_or(): + run_logical_test(mnp_logical_or, onp_logical_or, test_case) + + +def mnp_logical_xor(x1, x2): + return mnp.logical_xor(x1, x2) + + +def onp_logical_xor(x1, x2): + return onp.logical_xor(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logical_xor(): + run_logical_test(mnp_logical_xor, onp_logical_xor, test_case) + + +def mnp_logical_and(x1, x2): + return mnp.logical_and(x1, x2) + + +def onp_logical_and(x1, x2): + return onp.logical_and(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logical_and(): + run_logical_test(mnp_logical_and, onp_logical_and, test_case) + + +def mnp_logical_not(x): + return mnp.logical_not(x) + + +def onp_logical_not(x): + return onp.logical_not(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logical_not(): + for arr in test_case.boolean_arrs: + expected = onp_logical_not(arr) + actual = mnp_logical_not(to_tensor(arr)) + onp.testing.assert_equal(actual.asnumpy().tolist(), expected.tolist()) diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index 1fc2b86e06..7ba20ee2a4 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -20,7 +20,7 @@ import numpy as onp import mindspore.numpy as mnp from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \ - run_single_test, match_res, match_array, match_meta + run_single_test, match_res, match_array, match_meta, match_all_arrays, to_tensor class Cases(): def __init__(self): @@ -253,29 +253,6 @@ def test_minimum(): run_binop_test(mnp_minimum, onp_minimum, test_case) -def mnp_add_kwargs(x, y, where=None, out=None): - return mnp.add(x, y, where=where, out=out) - - -def onp_add_kwargs(x, y, where=None, out=None): - return onp.add(x, y, where=where, out=out) - - -@pytest.mark.level1 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_x86_gpu_training -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_add_kwargs(): - for where in test_case.bool_broadcastables[:2]: - for x in test_case.broadcastables[:2]: - for y in test_case.broadcastables[:2]: - shape_out = onp.broadcast(where, x, y).shape - out = rand_int(*shape_out) - match_res(mnp_add_kwargs, onp_add_kwargs, x, y, where, out) - - def mnp_tensordot(x, y): a = mnp.tensordot(x, y) b = mnp.tensordot(x, y, axes=0) @@ -351,21 +328,64 @@ def test_std(): run_single_test(mnp_std, onp_std, arr2, error=1e-5) +def mnp_nanstd(x): + a = mnp.nanstd(x) + b = mnp.nanstd(x, axis=None) + c = mnp.nanstd(x, axis=0) + d = mnp.nanstd(x, axis=1) + e = mnp.nanstd(x, axis=(-1, 1)) + f = mnp.nanstd(x, axis=(0, 1, 2)) + g = mnp.nanstd(x, axis=None, ddof=1, keepdims=True) + h = mnp.nanstd(x, axis=0, ddof=1, keepdims=True) + i = mnp.nanstd(x, axis=(2), ddof=1, keepdims=True) + return a, b, c, d, e, f, g, h, i + + +def onp_nanstd(x): + a = onp.nanstd(x) + b = onp.nanstd(x, axis=None) + c = onp.nanstd(x, axis=0) + d = onp.nanstd(x, axis=1) + e = onp.nanstd(x, axis=(-1, 1)) + f = onp.nanstd(x, axis=(0, 1, 2)) + g = onp.nanstd(x, axis=None, ddof=1, keepdims=True) + h = onp.nanstd(x, axis=0, ddof=1, keepdims=True) + i = onp.nanstd(x, axis=(2), ddof=1, keepdims=True) + return a, b, c, d, e, f, g, h, i + + +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu 
+@pytest.mark.env_onecard +def test_nanstd(): + arr1 = rand_int(2, 3, 4, 5) + arr1[0][2][1][3] = onp.nan + arr1[1][0][2][4] = onp.nan + arr1[1][1][1][1] = onp.nan + arr2 = rand_int(4, 5, 4, 3, 3) + arr2[3][1][2][1][0] = onp.nan + arr2[1][1][1][1][1] = onp.nan + arr2[0][4][3][0][2] = onp.nan + run_single_test(mnp_nanstd, onp_nanstd, arr1, error=1e-5) + run_single_test(mnp_nanstd, onp_nanstd, arr2, error=1e-5) + + def mnp_var(x): - a = mnp.std(x) - b = mnp.std(x, axis=0) - c = mnp.std(x, axis=(0)) - d = mnp.std(x, axis=(0, 1, 2)) - e = mnp.std(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + a = mnp.var(x) + b = mnp.var(x, axis=0) + c = mnp.var(x, axis=(0)) + d = mnp.var(x, axis=(0, 1, 2)) + e = mnp.var(x, axis=(-1, 1, 2), ddof=1, keepdims=True) return a, b, c, d, e def onp_var(x): - a = onp.std(x) - b = onp.std(x, axis=0) - c = onp.std(x, axis=(0)) - d = onp.std(x, axis=(0, 1, 2)) - e = onp.std(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + a = onp.var(x) + b = onp.var(x, axis=0) + c = onp.var(x, axis=(0)) + d = onp.var(x, axis=(0, 1, 2)) + e = onp.var(x, axis=(-1, 1, 2), ddof=1, keepdims=True) return a, b, c, d, e @@ -382,6 +402,41 @@ def test_var(): run_single_test(mnp_var, onp_var, arr2, error=1e-5) +def mnp_nanvar(x): + a = mnp.nanvar(x) + b = mnp.nanvar(x, axis=0) + c = mnp.nanvar(x, axis=(0)) + d = mnp.nanvar(x, axis=(0, 1, 2)) + e = mnp.nanvar(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + return a, b, c, d, e + + +def onp_nanvar(x): + a = onp.nanvar(x) + b = onp.nanvar(x, axis=0) + c = onp.nanvar(x, axis=(0)) + d = onp.nanvar(x, axis=(0, 1, 2)) + e = onp.nanvar(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + return a, b, c, d, e + + +@pytest.mark.level1 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_nanvar(): + arr1 = rand_int(2, 3, 4, 5) + arr1[0][2][1][3] = onp.nan + arr1[1][0][2][4] = onp.nan + arr1[1][1][1][1] = onp.nan + arr2 = rand_int(4, 5, 4, 3, 3) + arr2[3][1][2][1][0] = onp.nan + arr2[1][1][1][1][1] = onp.nan + arr2[0][4][3][0][2] = onp.nan + run_single_test(mnp_nanvar, onp_nanvar, arr1, error=1e-5) + run_single_test(mnp_nanvar, onp_nanvar, arr2, error=1e-5) + + def mnp_average(x): a = mnp.average(x) b = mnp.average(x, axis=None) @@ -544,9 +599,9 @@ def test_type_promotion(): arr = rand_int(2, 3) onp_sum = onp_add(arr, arr) - a = mnp.asarray(arr, dtype='float16') - b = mnp.asarray(arr, dtype='float32') - c = mnp.asarray(arr, dtype='int32') + a = to_tensor(arr, dtype=mnp.float16) + b = to_tensor(arr, dtype=mnp.float32) + c = to_tensor(arr, dtype=mnp.int32) match_array(mnp_add(a, b).asnumpy(), onp_sum) match_array(mnp_add(b, c).asnumpy(), onp_sum) @@ -569,21 +624,16 @@ def onp_absolute(x): def test_absolute(): arr = rand_int(2, 3) - a = mnp.asarray(arr, dtype='float16') - b = mnp.asarray(arr, dtype='float32') - c = mnp.asarray(arr, dtype='uint8') - d = mnp.asarray(arr, dtype='bool') + a = to_tensor(arr, dtype=mnp.float16) + b = to_tensor(arr, dtype=mnp.float32) + c = to_tensor(arr, dtype=mnp.uint8) + d = to_tensor(arr, dtype=mnp.bool_) match_array(mnp_absolute(a).asnumpy(), onp_absolute(a.asnumpy())) match_array(mnp_absolute(b).asnumpy(), onp_absolute(b.asnumpy())) match_array(mnp_absolute(c).asnumpy(), onp_absolute(c.asnumpy())) match_array(mnp_absolute(d).asnumpy(), onp_absolute(d.asnumpy())) - where = rand_int(2, 3).astype('bool') - out = rand_int(2, 3) - match_array(mnp.absolute(a, out=mnp.asarray(out), where=mnp.asarray(where)).asnumpy(), - onp.absolute(a.asnumpy(), out=out, where=where)) - @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training
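The `out=`/`where=` variants were dropped from these math tests because, as the Notes in the wrappers' docstrings state, mindspore.numpy does not support NumPy's `out`, `where`, `casting`, `order`, `subok`, `signature` and `extobj` arguments; the functions are therefore exercised in plain element-wise form, for example (illustrative only, output formatting may vary by backend):
>>> import mindspore.numpy as np
>>> x = np.array([-1.5, 0.0, 2.5]).astype('float32')
>>> print(np.absolute(x))
[1.5 0.  2.5]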
@@ -626,20 +676,12 @@ def test_ptp(): match_res(mnp_ptp, onp_ptp, arr) -def mnp_add_dtype(x1, x2, out, where): - a = mnp.add(x1, x2, dtype=mnp.float16) - b = mnp.add(x1, x2, out=out, dtype=mnp.float16) - c = mnp.add(x1, x2, where=where, dtype=mnp.float16) - d = mnp.add(x1, x2, out=out, where=where, dtype=mnp.float16) - return a, b, c, d +def mnp_add_dtype(x1, x2): + return mnp.add(x1, x2, dtype=mnp.float16) -def onp_add_dtype(x1, x2, out, where): - a = onp.add(x1, x2, dtype=onp.float16) - b = onp.add(x1, x2, out=out, dtype=onp.float16) - c = onp.add(x1, x2, where=where, dtype=onp.float16) - d = onp.add(x1, x2, out=out, where=where, dtype=onp.float16) - return a, b, c, d +def onp_add_dtype(x1, x2): + return onp.add(x1, x2, dtype=onp.float16) @pytest.mark.level1 @@ -651,10 +693,8 @@ def onp_add_dtype(x1, x2, out, where): def test_add_dtype(): x1 = rand_int(2, 3).astype('int32') x2 = rand_int(2, 3).astype('int32') - out = rand_int(2, 3).astype('float32') - where = rand_bool(2, 3) - arrs = (x1, x2, out, where) - mnp_arrs = map(mnp.array, arrs) + arrs = (x1, x2) + mnp_arrs = map(to_tensor, arrs) mnp_res = mnp_add_dtype(*mnp_arrs) onp_res = onp_add_dtype(*arrs) for actual, expected in zip(mnp_res, onp_res): @@ -758,6 +798,116 @@ def test_log(): run_unary_test(mnp.log, onp.log, test_case, error=1e-5) +def mnp_log1p(x): + return mnp.log1p(x) + + +def onp_log1p(x): + return onp.log1p(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_log1p(): + run_unary_test(mnp_log1p, onp_log1p, test_case, error=1e-5) + + +def mnp_logaddexp(x1, x2): + return mnp.logaddexp(x1, x2) + + +def onp_logaddexp(x1, x2): + return onp.logaddexp(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logaddexp(): + test_cases = [ + onp.random.randint(1, 5, (2)).astype('float16'), + onp.random.randint(1, 5, (3, 2)).astype('float16'), + onp.random.randint(1, 5, (1, 3, 2)).astype('float16'), + onp.random.randint(1, 5, (5, 6, 3, 2)).astype('float16')] + for _, x1 in enumerate(test_cases): + for _, x2 in enumerate(test_cases): + expected = onp_logaddexp(x1, x2) + actual = mnp_logaddexp(to_tensor(x1), to_tensor(x2)) + onp.testing.assert_almost_equal(actual.asnumpy().tolist(), expected.tolist(), + decimal=2) + + +def mnp_log2(x): + return mnp.log2(x) + + +def onp_log2(x): + return onp.log2(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_log2(): + run_unary_test(mnp_log2, onp_log2, test_case, error=1e-5) + + +def mnp_logaddexp2(x1, x2): + return mnp.logaddexp2(x1, x2) + + +def onp_logaddexp2(x1, x2): + return onp.logaddexp2(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_logaddexp2(): + test_cases = [ + onp.random.randint(1, 5, (2)).astype('float16'), + onp.random.randint(1, 5, (3, 2)).astype('float16'), + onp.random.randint(1, 5, (1, 3, 2)).astype('float16'), + onp.random.randint(1, 5, (5, 6, 3, 2)).astype('float16')] + for _, x1 in 
enumerate(test_cases): + for _, x2 in enumerate(test_cases): + expected = onp_logaddexp2(x1, x2) + actual = mnp_logaddexp2(to_tensor(x1), to_tensor(x2)) + onp.testing.assert_almost_equal(actual.asnumpy().tolist(), expected.tolist(), + decimal=2) + + +def mnp_log10(x): + return mnp.log10(x) + + +def onp_log10(x): + return onp.log10(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_log10(): + run_unary_test(mnp_log10, onp_log10, test_case, error=1e-5) + + def mnp_maximum(x1, x2): return mnp.maximum(x1, x2) @@ -777,9 +927,9 @@ def test_maximum(): def mnp_clip(x): - a = mnp.clip(x, mnp.asarray(10.0), mnp.asarray([2,])) + a = mnp.clip(x, to_tensor(10.0), to_tensor([2,])) b = mnp.clip(x, 0, 1) - c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float32) + c = mnp.clip(x, to_tensor(0), to_tensor(10), dtype=mnp.float32) return a, b, c @@ -842,8 +992,8 @@ def mnp_amin(x, mask): c = mnp.amin(x, keepdims=True) d = mnp.amin(x, initial=-1) e = mnp.amin(x, axis=(0, 1), keepdims=True) - f = mnp.amin(x, initial=-2, where=mask) - g = mnp.amin(x, initial=-3, where=mask, keepdims=True) + f = mnp.amin(x, initial=-2) + g = mnp.amin(x, initial=-3, keepdims=True) h = mnp.amin(x, axis=(1, 2, 3), initial=-4, where=mask) return a, b, c, d, e, f, g, h @@ -854,8 +1004,8 @@ def onp_amin(x, mask): c = onp.amin(x, keepdims=True) d = onp.amin(x, initial=-1) e = onp.amin(x, axis=(0, 1), keepdims=True) - f = onp.amin(x, initial=-2, where=mask) - g = onp.amin(x, initial=-3, where=mask, keepdims=True) + f = onp.amin(x, initial=-2) + g = onp.amin(x, initial=-3, keepdims=True) h = onp.amin(x, axis=(1, 2, 3), initial=-4, where=mask) return a, b, c, d, e, f, g, h @@ -1088,12 +1238,84 @@ def test_expm1(): run_unary_test(mnp_expm1, onp_expm1, test_case, error=5) -def mnp_positive(x, out, where): - return mnp.positive(x, out=out, where=where) +def mnp_exp2(x): + return mnp.exp2(x) + + +def onp_exp2(x): + return onp.exp2(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_exp2(): + run_unary_test(mnp_exp2, onp_exp2, test_case, error=5) + + +def mnp_kron(x, y): + return mnp.kron(x, y) -def onp_positive(x, out, where): - return onp.positive(x, out=out, where=where) +def onp_kron(x, y): + return onp.kron(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_kron(): + run_binop_test(mnp_kron, onp_kron, test_case) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cross(): + x = onp.arange(8).reshape(2, 2, 1, 2) + y = onp.arange(4).reshape(1, 2, 2) + match_res(mnp.cross, onp.cross, x, y) + match_res(mnp.cross, onp.cross, x, y, axisa=-3, axisb=1, axisc=2) + match_res(mnp.cross, onp.cross, x, y, axisa=-3, axisb=1, axisc=2, axis=1) + x = onp.arange(18).reshape(2, 3, 1, 3) + y = onp.arange(9).reshape(1, 3, 3) + match_res(mnp.cross, onp.cross, x, y) + match_res(mnp.cross, onp.cross, x, y, axisa=-3, axisb=1, axisc=2) + match_res(mnp.cross, onp.cross, x, y, axisa=-3, 
axisb=1, axisc=2, axis=1) + + +def mnp_ceil(x): + return mnp.ceil(x) + + +def onp_ceil(x): + return onp.ceil(x) + + +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ceil(): + run_unary_test(mnp_ceil, onp_ceil, test_case) + + +def mnp_positive(x): + return mnp.positive(x) + + +def onp_positive(x): + return onp.positive(x) @pytest.mark.level1 @@ -1104,21 +1326,17 @@ def onp_positive(x, out, where): @pytest.mark.env_onecard def test_positive(): arr = onp.arange(-6, 6).reshape((2, 2, 3)).astype('float32') - out_lst = [onp.ones((2, 2, 3)).astype('float32'), onp.ones((5, 2, 2, 3)).astype('float32')] - where_lst = [onp.full((2, 2, 3), [True, False, True]), onp.full((2, 3), False)] - for out in out_lst: - for where in where_lst: - onp_pos = onp_positive(arr, out=out, where=where) - mnp_pos = mnp_positive(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where)) - match_array(mnp_pos.asnumpy(), onp_pos) + onp_pos = onp_positive(arr) + mnp_pos = mnp_positive(to_tensor(arr)) + match_array(mnp_pos.asnumpy(), onp_pos) -def mnp_negative(x, out, where): - return mnp.negative(x, out=out, where=where) +def mnp_negative(x): + return mnp.negative(x) -def onp_negative(x, out, where): - return onp.negative(x, out=out, where=where) +def onp_negative(x): + return onp.negative(x) @pytest.mark.level1 @@ -1129,13 +1347,9 @@ def onp_negative(x, out, where): @pytest.mark.env_onecard def test_negative(): arr = onp.arange(-6, 6).reshape((2, 2, 3)).astype('float32') - out_lst = [onp.ones((2, 2, 3)).astype('float32'), onp.ones((5, 2, 2, 3)).astype('float32')] - where_lst = [onp.full((2, 2, 3), [True, False, True]), onp.full((2, 3), False)] - for out in out_lst: - for where in where_lst: - onp_neg = onp_negative(arr, out=out, where=where) - mnp_neg = mnp_negative(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where)) - match_array(mnp_neg.asnumpy(), onp_neg, 1e-5) + onp_neg = onp_negative(arr) + mnp_neg = mnp_negative(to_tensor(arr)) + match_array(mnp_neg.asnumpy(), onp_neg, 1e-5) @pytest.mark.level1 @@ -1152,12 +1366,552 @@ def test_cumsum(): match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) x = rand_int(3, 4, 5) - match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), + match_array(mnp.cumsum(to_tensor(x), dtype="bool").asnumpy(), onp.cumsum(x, dtype="bool")) - match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), + match_array(mnp.cumsum(to_tensor(x), axis=-1).asnumpy(), onp.cumsum(x, axis=-1)) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_promote_types(): + assert mnp.promote_types(mnp.int32, mnp.bool_) == mnp.int32 + assert mnp.promote_types(int, mnp.bool_) == mnp.int32 + assert mnp.promote_types("float32", mnp.int64) == mnp.float32 + assert mnp.promote_types(mnp.int64, mnp.float16) == mnp.float16 + assert mnp.promote_types(int, float) == mnp.float32 + + +def mnp_diff(input_tensor): + a = mnp.diff(input_tensor, 2, append=3.0) + b = mnp.diff(input_tensor, 4, prepend=6, axis=-2) + c = mnp.diff(input_tensor, 0, append=3.0, axis=-1) + d = mnp.diff(input_tensor, 10, prepend=6) + e = mnp.diff(input_tensor, 1, prepend=input_tensor) + f = mnp.ediff1d(input_tensor, to_end=input_tensor) + g = mnp.ediff1d(input_tensor) + h = mnp.ediff1d(input_tensor, to_begin=3) + return a, b, c, d, e, f, g, h + + 
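As a quick reference for the `diff`/`ediff1d` cases exercised by the helper above and its NumPy counterpart below, the baseline NumPy semantics being matched are:
>>> import numpy as onp
>>> a = onp.array([1, 4, 9, 16])
>>> onp.diff(a)
array([3, 5, 7])
>>> onp.diff(a, n=2)
array([2, 2])
>>> onp.ediff1d(a, to_begin=0)
array([0, 3, 5, 7])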
+def onp_diff(input_array): + a = onp.diff(input_array, 2, append=3.0) + b = onp.diff(input_array, 4, prepend=6, axis=-2) + c = onp.diff(input_array, 0, append=3.0, axis=-1) + d = onp.diff(input_array, 10, prepend=6) + e = onp.diff(input_array, 1, prepend=input_array) + f = onp.ediff1d(input_array, to_end=input_array) + g = onp.ediff1d(input_array) + h = onp.ediff1d(input_array, to_begin=3) + return a, b, c, d, e, f, g, h + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_diff(): + arr = rand_int(3, 4, 5) + match_res(mnp_diff, onp_diff, arr) + arr = rand_int(1, 4, 6, 3) + match_res(mnp_diff, onp_diff, arr) + + +def mnp_sin(x): + return mnp.sin(x) + + +def onp_sin(x): + return onp.sin(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_sin(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_sin(arr) + actual = mnp_sin(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_cos(x): + return mnp.cos(x) + + +def onp_cos(x): + return onp.cos(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cos(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_cos(arr) + actual = mnp_cos(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_tan(x): + return mnp.tan(x) + + +def onp_tan(x): + return onp.tan(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tan(): + arr = onp.array([-0.75, -0.5, 0, 0.5, 0.75]).astype('float32') + expect = onp_tan(arr) + actual = mnp_tan(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_arcsin(x): + return mnp.arcsin(x) + + +def onp_arcsin(x): + return onp.arcsin(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arcsin(): + arr = onp.random.uniform(-1, 1, 12).astype('float32') + onp_asin = onp_arcsin(arr) + mnp_asin = mnp_arcsin(to_tensor(arr)) + match_array(mnp_asin.asnumpy(), onp_asin, error=5) + + +def mnp_arccos(x): + return mnp.arccos(x) + + +def onp_arccos(x): + return onp.arccos(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arccos(): + arr = onp.random.uniform(-1, 1, 12).astype('float32') + onp_acos = onp_arccos(arr) + mnp_acos = mnp_arccos(to_tensor(arr)) + match_array(mnp_acos.asnumpy(), onp_acos, error=5) + + +def mnp_arctan(x): + return mnp.arctan(x) + + +def onp_arctan(x): + return onp.arctan(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arctan(): + arr = onp.random.uniform(-1, 1, 12).astype('float32') + onp_atan = 
onp_arctan(arr) + mnp_atan = mnp_arctan(to_tensor(arr)) + match_array(mnp_atan.asnumpy(), onp_atan, error=5) + + +def mnp_sinh(x): + return mnp.sinh(x) + + +def onp_sinh(x): + return onp.sinh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_sinh(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_sinh(arr) + actual = mnp_sinh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_cosh(x): + return mnp.cosh(x) + + +def onp_cosh(x): + return onp.cosh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cosh(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_cosh(arr) + actual = mnp_cosh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_tanh(x): + return mnp.tanh(x) + + +def onp_tanh(x): + return onp.tanh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tanh(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_tanh(arr) + actual = mnp_tanh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_arcsinh(x): + return mnp.arcsinh(x) + + +def onp_arcsinh(x): + return onp.arcsinh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arcsinh(): + arr = onp.random.rand(2, 3, 4).astype('float32') + expect = onp_arcsinh(arr) + actual = mnp_arcsinh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_arccosh(x): + return mnp.arccosh(x) + + +def onp_arccosh(x): + return onp.arccosh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arccosh(): + arr = onp.random.randint(1, 100, size=(2, 3)).astype('float32') + expect = onp_arccosh(arr) + actual = mnp_arccosh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_arctanh(x): + return mnp.arctanh(x) + + +def onp_arctanh(x): + return onp.arctanh(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arctanh(): + arr = onp.random.uniform(-0.9, 1, 10).astype('float32') + expect = onp_arctanh(arr) + actual = mnp_arctanh(to_tensor(arr)) + match_array(actual.asnumpy(), expect, error=5) + + +def mnp_arctan2(x, y): + return mnp.arctan2(x, y) + + +def onp_arctan2(x, y): + return onp.arctan2(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_arctan2(): + run_binop_test(mnp_arctan2, onp_arctan2, test_case) + + +def mnp_convolve(mode): + a = mnp.convolve([1, 2, 3, 4, 5], 2, mode=mode) + b = mnp.convolve([1, 2, 3, 4, 5], [2, 3], mode=mode) + c = mnp.convolve([1, 2], [2, 5, 10], mode=mode) + d = mnp.convolve(mnp.array([1, 2, 3, 4, 5]), mnp.array([1, 2, 3, 4, 5]), mode=mode) + e = 
+
+
+def mnp_convolve(mode):
+    a = mnp.convolve([1, 2, 3, 4, 5], 2, mode=mode)
+    b = mnp.convolve([1, 2, 3, 4, 5], [2, 3], mode=mode)
+    c = mnp.convolve([1, 2], [2, 5, 10], mode=mode)
+    d = mnp.convolve(mnp.array([1, 2, 3, 4, 5]), mnp.array([1, 2, 3, 4, 5]), mode=mode)
+    e = mnp.convolve([1, 2, 3, 4, 5], 2, mode=mode)
+    return a, b, c, d, e
+
+
+def onp_convolve(mode):
+    a = onp.convolve([1, 2, 3, 4, 5], 2, mode=mode)
+    b = onp.convolve([1, 2, 3, 4, 5], [2, 3], mode=mode)
+    c = onp.convolve([1, 2], [2, 5, 10], mode=mode)
+    d = onp.convolve(onp.array([1, 2, 3, 4, 5]), onp.array([1, 2, 3, 4, 5]), mode=mode)
+    e = onp.convolve([1, 2, 3, 4, 5], 2, mode=mode)
+    return a, b, c, d, e
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_convolve():
+    for mode in ['full', 'same', 'valid']:
+        mnp_res = mnp_convolve(mode)
+        onp_res = onp_convolve(mode)
+        match_all_arrays(mnp_res, onp_res)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cov():
+    x = onp.random.random((3, 4)).tolist()
+    mnp_res = mnp.cov(x)
+    onp_res = onp.cov(x)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x[0])
+    onp_res = onp.cov(x[0])
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    w1 = [0, 1, 2, 3]
+    w2 = [4, 5, 6, 7]
+    mnp_res = mnp.cov(x, fweights=w1)
+    onp_res = onp.cov(x, fweights=w1)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x, aweights=w2)
+    onp_res = onp.cov(x, aweights=w2)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x, fweights=w1, aweights=w2)
+    onp_res = onp.cov(x, fweights=w1, aweights=w2)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x, fweights=w1, aweights=w2, ddof=3)
+    onp_res = onp.cov(x, fweights=w1, aweights=w2, ddof=3)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x, fweights=w1, aweights=w2, bias=True)
+    onp_res = onp.cov(x, fweights=w1, aweights=w2, bias=True)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+    mnp_res = mnp.cov(x, fweights=w1[0:3], aweights=w2[0:3], rowvar=False, bias=True)
+    onp_res = onp.cov(x, fweights=w1[0:3], aweights=w2[0:3], rowvar=False, bias=True)
+    match_all_arrays(mnp_res, onp_res, error=1e-5)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_trapz():
+    y = rand_int(2, 3, 4, 5)
+    match_res(mnp.trapz, onp.trapz, y)
+    match_res(mnp.trapz, onp.trapz, y, x=[-5, -3, 0, 7, 10])
+    match_res(mnp.trapz, onp.trapz, y, dx=2, axis=3)
+    match_res(mnp.trapz, onp.trapz, y, x=[1, 5, 6, 9], dx=3, axis=-2)
+
+
+def mnp_gcd(x, y):
+    return mnp.gcd(x, y)
+
+
+def onp_gcd(x, y):
+    return onp.gcd(x, y)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_gcd():
+    x = onp.arange(-12, 12).reshape(2, 3, 4)
+    y = onp.arange(24).reshape(2, 3, 4)
+    match_res(mnp_gcd, onp_gcd, x, y)
+
+
+def mnp_lcm(x, y):
+    return mnp.lcm(x, y)
+
+
+def onp_lcm(x, y):
+    return onp.lcm(x, y)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_lcm():
+    x = onp.arange(-12, 12).reshape(2, 3, 4)
+    y = onp.arange(24).reshape(2, 3, 4)
+    match_res(mnp_lcm, onp_lcm, x, y)
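Editorial aside (not part of the patch): the gcd/lcm grids above include zeros and negative values, and the reference behaviour being matched is NumPy's, where both results are defined as non-negative. A NumPy-only check with invented operands:

    # NumPy-only illustration; operand values are invented.
    import numpy as np

    a = np.array([-12, -9, 0, 8])
    b = np.array([18, 6, 5, 0])
    print(np.gcd(a, b))  # [6 3 5 8]: gcd(x, 0) == |x|
    print(np.lcm(a, b))  # [36 18  0  0]: lcm with 0 is 0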
+
+
+def mnp_nansum(x):
+    a = mnp.nansum(x)
+    b = mnp.nansum(x, keepdims=True)
+    c = mnp.nansum(x, axis=-2)
+    d = mnp.nansum(x, axis=0, keepdims=True)
+    e = mnp.nansum(x, axis=(-2, 3))
+    f = mnp.nansum(x, axis=(-3, -1), keepdims=True)
+    return a, b, c, d, e, f
+
+
+def onp_nansum(x):
+    a = onp.nansum(x)
+    b = onp.nansum(x, keepdims=True)
+    c = onp.nansum(x, axis=-2)
+    d = onp.nansum(x, axis=0, keepdims=True)
+    e = onp.nansum(x, axis=(-2, 3))
+    f = onp.nansum(x, axis=(-3, -1), keepdims=True)
+    return a, b, c, d, e, f
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_nansum():
+    x = rand_int(2, 3, 4, 5)
+    x[0][2][1][3] = onp.nan
+    x[1][0][2][4] = onp.nan
+    x[1][1][1][1] = onp.nan
+    run_multi_test(mnp_nansum, onp_nansum, (x,))
+
+
+def mnp_nanmean(x):
+    a = mnp.nanmean(x)
+    b = mnp.nanmean(x, keepdims=True)
+    c = mnp.nanmean(x, axis=-2)
+    d = mnp.nanmean(x, axis=0, keepdims=True)
+    e = mnp.nanmean(x, axis=(-2, 3))
+    f = mnp.nanmean(x, axis=(-3, -1), keepdims=True)
+    return a, b, c, d, e, f
+
+
+def onp_nanmean(x):
+    a = onp.nanmean(x)
+    b = onp.nanmean(x, keepdims=True)
+    c = onp.nanmean(x, axis=-2)
+    d = onp.nanmean(x, axis=0, keepdims=True)
+    e = onp.nanmean(x, axis=(-2, 3))
+    f = onp.nanmean(x, axis=(-3, -1), keepdims=True)
+    return a, b, c, d, e, f
+
+
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_nanmean():
+    x = rand_int(2, 3, 4, 5)
+    x[0][2][1][3] = onp.nan
+    x[1][0][2][4] = onp.nan
+    x[1][1][1][1] = onp.nan
+    run_multi_test(mnp_nanmean, onp_nanmean, (x,))
+
+
+def mnp_mean(*arrs):
+    arr1 = arrs[0]
+    arr2 = arrs[1]
+    arr3 = arrs[2]
+    a = mnp.mean(arr1)
+    b = mnp.mean(arr2, keepdims=True)
+    c = mnp.mean(arr3, keepdims=False)
+    d = mnp.mean(arr2, axis=0, keepdims=True)
+    e = mnp.mean(arr3, axis=(0, -1))
+    f = mnp.mean(arr3, axis=-1, keepdims=True)
+    return a, b, c, d, e, f
+
+
+def onp_mean(*arrs):
+    arr1 = arrs[0]
+    arr2 = arrs[1]
+    arr3 = arrs[2]
+    a = onp.mean(arr1)
+    b = onp.mean(arr2, keepdims=True)
+    c = onp.mean(arr3, keepdims=False)
+    d = onp.mean(arr2, axis=0, keepdims=True)
+    e = onp.mean(arr3, axis=(0, -1))
+    f = onp.mean(arr3, axis=-1, keepdims=True)
+    return a, b, c, d, e, f
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_mean():
+    run_multi_test(mnp_mean, onp_mean, test_case.arrs, error=3)
+    run_multi_test(mnp_mean, onp_mean, test_case.expanded_arrs, error=3)
+    run_multi_test(mnp_mean, onp_mean, test_case.scalars, error=3)
+    run_multi_test(mnp_mean, onp_mean, test_case.empty_arrs, error=3)
+
+
 @pytest.mark.level1
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -1166,8 +1920,8 @@ def test_cumsum():
 @pytest.mark.env_onecard
 def test_exception_innner():
     with pytest.raises(ValueError):
-        mnp.inner(mnp.asarray(test_case.arrs[0]),
-                  mnp.asarray(test_case.arrs[1]))
+        mnp.inner(to_tensor(test_case.arrs[0]),
+                  to_tensor(test_case.arrs[1]))
 
 
 @pytest.mark.level1
@@ -1178,7 +1932,7 @@ def test_exception_innner():
 @pytest.mark.env_onecard
 def test_exception_add():
     with pytest.raises(ValueError):
-        mnp.add(mnp.asarray(test_case.arrs[1]), mnp.asarray(test_case.arrs[2]))
+        mnp.add(to_tensor(test_case.arrs[1]), to_tensor(test_case.arrs[2]))
 
 
 @pytest.mark.level1
@@ -1189,4 +1943,4 @@
 @pytest.mark.env_onecard
 def test_exception_mean():
     with pytest.raises(ValueError):
-        mnp.mean(mnp.asarray(test_case.arrs[0]), (-1, 0))
+        mnp.mean(to_tensor(test_case.arrs[0]), (-1, 0))
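Editorial aside (not part of the patch): the nansum/nanmean cases above plant NaNs at fixed positions and then sweep axis/keepdims combinations against NumPy, whose convention is that nansum treats NaN as zero while nanmean drops NaN from both the sum and the divisor. A NumPy-only reminder on an invented array:

    # NumPy-only illustration; the array and NaN positions are invented.
    import numpy as np

    x = np.array([[1.0, np.nan, 3.0], [np.nan, 5.0, 6.0]])
    print(np.sum(x))              # nan: a single NaN poisons the ordinary reduction
    print(np.nansum(x))           # 15.0
    print(np.nanmean(x, axis=0))  # [1.  5.  4.5]: averaged over non-NaN entries only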
diff --git a/tests/st/numpy_native/utils.py b/tests/st/numpy_native/utils.py
index 8a0fdd5930..3568b24469 100644
--- a/tests/st/numpy_native/utils.py
+++ b/tests/st/numpy_native/utils.py
@@ -15,6 +15,7 @@
 """utility functions for mindspore.numpy st tests"""
 import functools
 import numpy as onp
+from mindspore import Tensor
 
 import mindspore.numpy as mnp
 
@@ -90,7 +91,9 @@ def rand_bool(*shape):
 
 def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
     """Checks results from applying mnp_fn and onp_fn on arrs respectively"""
-    mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
+    dtype = kwargs.get('dtype', mnp.float32)
+    kwargs.pop('dtype', None)
+    mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
     error = kwargs.get('error', 0)
     kwargs.pop('error', None)
     mnp_res = mnp_fn(*mnp_arrs, **kwargs)
@@ -151,15 +154,32 @@ def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
 
 
 def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
-    mnp_arrs = map(mnp.asarray, arrs)
+    mnp_arrs = map(Tensor, arrs)
     for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
-        match_array(actual.asnumpy(), expected, error)
+        match_all_arrays(actual, expected, error)
 
 
 def run_single_test(mnp_fn, onp_fn, arr, error=0):
-    mnp_arr = mnp.asarray(arr)
+    mnp_arr = Tensor(arr)
     for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
         if isinstance(expected, tuple):
             for actual_arr, expected_arr in zip(actual, expected):
                 match_array(actual_arr.asnumpy(), expected_arr, error)
         match_array(actual.asnumpy(), expected, error)
+
+
+def run_logical_test(mnp_fn, onp_fn, test_case):
+    for x1 in test_case.boolean_arrs:
+        for x2 in test_case.boolean_arrs:
+            match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)
+
+def to_tensor(obj, dtype=None):
+    if dtype is None:
+        res = Tensor(obj)
+        if res.dtype == mnp.float64:
+            res = res.astype(mnp.float32)
+        if res.dtype == mnp.int64:
+            res = res.astype(mnp.int32)
+    else:
+        res = Tensor(obj, dtype)
+    return res
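Closing aside (not part of the patch): the new to_tensor helper downcasts 64-bit NumPy inputs because MindSpore kernels generally expect 32-bit dtypes; a minimal sketch of that behaviour, assuming a working MindSpore install, mirrors the helper rather than reproducing it.

    # Minimal sketch of the downcast rule used by to_tensor above; not the helper itself.
    import numpy as onp
    from mindspore import Tensor
    from mindspore import dtype as mstype

    x64 = onp.arange(6, dtype=onp.float64).reshape(2, 3)
    t = Tensor(x64)                   # construction keeps float64
    if t.dtype == mstype.float64:
        t = t.astype(mstype.float32)  # same float64 -> float32 rule as the helper
    print(t.dtype)                    # Float32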