From c5ea8223f5b53a6eb97376c2e8fa56de6c76183e Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Thu, 24 Dec 2020 15:15:26 +0800 Subject: [PATCH] Add new np interfaces and add graph support --- mindspore/numpy/__init__.py | 45 +- mindspore/numpy/array_creations.py | 1240 ++++++---- mindspore/numpy/array_ops.py | 1238 ++++++++-- mindspore/numpy/dtypes.py | 5 + mindspore/numpy/logic_ops.py | 576 +++++ mindspore/numpy/math_ops.py | 2077 +++++++++++++++-- mindspore/numpy/utils.py | 86 +- mindspore/numpy/utils_const.py | 271 ++- mindspore/ops/functional.py | 8 + tests/st/numpy_native/test_array_creations.py | 309 ++- tests/st/numpy_native/test_array_ops.py | 550 ++++- tests/st/numpy_native/test_logic_ops.py | 263 +++ tests/st/numpy_native/test_math_ops.py | 883 +++++-- tests/st/numpy_native/utils.py | 165 ++ 14 files changed, 6528 insertions(+), 1188 deletions(-) create mode 100644 mindspore/numpy/logic_ops.py create mode 100644 tests/st/numpy_native/test_logic_ops.py create mode 100644 tests/st/numpy_native/utils.py diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py index 726ef318ac..8ac9969a05 100644 --- a/mindspore/numpy/__init__.py +++ b/mindspore/numpy/__init__.py @@ -22,36 +22,59 @@ Note: - array_ops.py defines all the array operation interfaces. - array_creations.py defines all the array generation interfaces. - math_ops.py defines all the math operations on tensors. + - logic_ops.py defines all the logical operations on tensors. - dtypes.py defines all the mindspore.numpy dtypes (mainly redirected from mindspore) """ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape, ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d, - column_stack, hstack, dstack, vstack, stack, unique) + column_stack, hstack, dstack, vstack, stack, unique, moveaxis, + tile, broadcast_to, broadcast_arrays, roll, append, split, vsplit, + flip, flipud, fliplr, hsplit, dsplit, take_along_axis, take, repeat) from .array_creations import copy_ as copy from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange, linspace, logspace, eye, identity, empty, empty_like, ones_like, zeros_like, full_like, diagonal, tril, triu, - tri, trace) + tri, trace, cumsum, meshgrid, mgrid, ogrid, diagflat, + diag, diag_indices, ix_) from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16, - uint32, uint64, float_, float16, float32, float64, bool_, inf, - numeric_types) -from .math_ops import (mean, inner, add, subtract, multiply, divide, power, - dot, outer, tensordot, absolute) + uint32, uint64, float_, float16, float32, float64, bool_, inf, nan, + numeric_types, PINF, NINF) +from .math_ops import (mean, inner, add, subtract, multiply, divide, true_divide, power, + dot, outer, tensordot, absolute, std, var, average, minimum, + matmul, square, sqrt, reciprocal, log, maximum, heaviside, amax, amin, + hypot, float_power, floor, ptp, deg2rad, rad2deg, count_nonzero, + positive, negative, clip, floor_divide, remainder, fix, fmod, trunc, + exp, expm1) +from .logic_ops import (not_equal, less_equal, less, greater_equal, greater, equal, isfinite, + isnan, isinf, isposinf, isneginf, isscalar) +mod = remainder +fabs = absolute array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape', 'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d', - 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique'] + 'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique', 'moveaxis', + 'tile', 
'broadcast_to', 'broadcast_arrays', 'append', 'roll', 'split', 'vsplit', + 'flip', 'flipud', 'fliplr', 'hsplit', 'dsplit', 'take_along_axis', 'take', + 'repeat'] array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange', 'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like', 'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu', - 'tri', 'trace'] + 'tri', 'trace', 'meshgrid', 'mgrid', 'ogrid', 'diagflat', 'diag', + 'diag_indices', 'ix_', 'cumsum'] -math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power', - 'dot', 'outer', 'tensordot', 'absolute'] +math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'true_divide', 'power', + 'dot', 'outer', 'tensordot', 'absolute', 'std', 'var', 'average', 'not_equal', + 'minimum', 'matmul', 'square', 'sqrt', 'reciprocal', 'log', 'maximum', + 'heaviside', 'amax', 'amin', 'hypot', 'float_power', 'floor', 'ptp', 'deg2rad', + 'rad2deg', 'count_nonzero', 'positive', 'negative', 'clip', 'floor_divide', + 'remainder', 'mod', 'fix', 'fmod', 'trunc', 'exp', 'expm1', 'fabs'] -__all__ = array_ops_module + array_creations_module + math_module + numeric_types +logic_module = ['not_equal', 'less_equal', 'less', 'greater_equal', 'greater', 'equal', 'isfinite', + 'isnan', 'isinf', 'isposinf', 'isneginf', 'isscalar'] + +__all__ = array_ops_module + array_creations_module + math_module + logic_module + numeric_types __all__.sort() diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 3cfb503e34..289d34e6fe 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -13,14 +13,14 @@ # limitations under the License. # ============================================================================ """array operations, the function docs are adapted from Numpy API.""" -from copy import copy as py_copy -from itertools import groupby +from copy import deepcopy import numpy as onp from ..common import Tensor from ..common import dtype as mstype from ..ops import functional as F +from ..ops import operations as P from ..ops.primitive import constexpr from ..nn.layer.basic import tril as nn_tril from ..nn.layer.basic import triu as nn_triu @@ -28,14 +28,20 @@ from .._c_expression import Tensor as Tensor_ from .._c_expression.typing import Float from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \ - _expand, _broadcast_to, _is_empty -from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, _check_same_type, \ - _check_shape_contain_zero, _check_shape, _check_dtype -from .array_ops import transpose + _expand, _broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar +from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \ + _check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \ + _raise_type_error, _expanded_shape, _check_axis_in_range, _check_is_float, _iota, \ + _type_convert, _canonicalize_axis, _list_comprehensions, _ceil +from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape +from .dtypes import nan # According to official numpy reference, the dimension of a numpy array must be less # than 32 MAX_NUMPY_DIMS = 32 +# All types that can be accepted as "array_like" parameters in graph mode. 
+ARRAY_TYPES = (int, float, bool, list, tuple, Tensor) +_cumsum_default = P.CumSum() def array(obj, dtype=None, copy=True, ndmin=0): @@ -46,13 +52,13 @@ def array(obj, dtype=None, copy=True, ndmin=0): Args: obj (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.int32, or `int32`. If dtype is None, the data type - of the new tensor will be inferred from obj. Default is None. - copy (bool): If true, then the object is copied. Otherwise, a copy will - only be made if necessary. Default: True. + any form that can be converted to a `Tensor`. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can + be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type + of the new tensor will be inferred from obj. Default is :class:`None`. + copy (bool): If `True`, then the object is copied. Otherwise, a copy will + only be made if necessary. Default: `True`. ndmin (int): Specifies the minimum number of dimensions that the resulting tensor should have. Ones will be pre-pended to the shape as needed to meet this requirement. Default: 0 @@ -81,7 +87,7 @@ def array(obj, dtype=None, copy=True, ndmin=0): if not copy: return asarray(obj, dtype=dtype) - obj = py_copy(obj) + obj = deepcopy(obj) return asarray(obj, dtype=dtype) @@ -93,11 +99,11 @@ def asarray(a, dtype=None): Args: a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and ndarrays. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.int32, or `int32`. If dtype is None, the data type - of the new tensor will be inferred from a. Default is None. + any form that can be converted to a `Tensor`. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can + be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type + of the new tensor will be inferred from obj. Default is :class:`None`. Returns: Tensor, generated tensor with the specified dtype. 
@@ -114,20 +120,13 @@ def asarray(a, dtype=None): >>> print(np.asarray([1,2,3])) [1 2 3] """ + _check_input_for_asarray(a) if dtype is not None: dtype = _check_dtype(dtype) - _ = _check_input_for_asarray(a) - - if isinstance(a, float) and (dtype is None): - dtype = mstype.float32 - - if isinstance(a, int) and not isinstance(a, bool) and (dtype is None): - dtype = mstype.int32 - - if isinstance(a, bool) and (dtype is None): - dtype = mstype.bool_ + if isinstance(a, (float, int, bool)) and dtype is None: + dtype = _get_dtype_from_scalar(a) if isinstance(a, (list, tuple)): # Convert all tuple/nested tuples to lists @@ -140,33 +139,29 @@ def asarray(a, dtype=None): # If dtype is not specified, we keep consistent with numpy decision # only exceptions are: we use int/float32 if dtype is None: - if a.dtype is onp.dtype('int64'): - dtype = mstype.int32 - elif a.dtype is onp.dtype('float64'): + dtype = mstype.pytype_to_dtype(a.dtype) + if dtype == mstype.float64: dtype = mstype.float32 + elif dtype == mstype.int64: + dtype = mstype.int32 if isinstance(a, onp.ndarray) and dtype is None: - if a.dtype is onp.dtype('bool'): - dtype = mstype.bool_ - elif a.dtype is onp.dtype('int'): - dtype = mstype.int32 - elif a.dtype is onp.dtype('float'): - dtype = mstype.float32 - elif a.dtype is onp.dtype('object'): + if a.dtype is onp.dtype('object'): raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") + dtype = mstype.pytype_to_dtype(a.dtype) a = Tensor.from_numpy(a) # If a is already a tensor and we don't need to cast dtype, return a if isinstance(a, Tensor): - if dtype is None: - return a - dtype = _check_dtype(dtype) - if dtype == a.dtype: + if dtype is None or dtype == a.dtype: return a return Tensor(a, dtype=dtype) +asarray_const = constexpr(asarray) + + def asfarray(a, dtype=mstype.float32): """ Similar to asarray, converts the input to a float tensor. @@ -175,10 +170,12 @@ def asfarray(a, dtype=mstype.float32): Args: a (Union[int, float, bool, list, tuple, numpy.ndarray]): Input data, in - any form that can be converted to a tensor. This includes lists, lists of - tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. + any form that can be converted to a `Tensor`. This includes lists, lists of + tuples, tuples, tuples of tuples, tuples of lists and numpy.ndarray. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, can + be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type + of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`. + Returns: Tensor, generated tensor with the specified float dtype. @@ -195,9 +192,12 @@ def asfarray(a, dtype=mstype.float32): >>> print(np.asfarray([1,2,3])) [1. 2. 3.] """ - dtype = _check_dtype(dtype) - _ = _check_input_for_asarray(a) + _check_input_for_asarray(a) + if dtype is None: + return asarray(a) + + dtype = _check_dtype(dtype) if dtype not in (mstype.float16, mstype.float32, mstype.float64): dtype = mstype.float32 @@ -242,14 +242,8 @@ def copy_(a): [1. 
1.]] """ if not isinstance(a, Tensor): - a = asarray(a) - return py_copy(a) - - -@constexpr -def _fill(shape, value, dtype): - """Original numpy.full function.""" - return Tensor(onp.full(shape, value), dtype) + a = asarray_const(a) + return a.copy() def ones(shape, dtype=mstype.float32): @@ -258,14 +252,15 @@ def ones(shape, dtype=mstype.float32): Args: shape (Union[int, tuple, list]): the shape of the new tensor. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype. + Default is :class:`mstype.float32`. Returns: - Tensor, with the designated shape and dtype, filled with ones. + Tensor, with the designated `shape` and `dtype`, filled with ones. Raises: TypeError: If input arguments have types not specified above. + ValueError: If `shape` entries have values :math:`< 0`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -278,8 +273,8 @@ def ones(shape, dtype=mstype.float32): """ shape = _check_shape(shape) dtype = _check_dtype(dtype) - if _check_shape_contain_zero(shape): - return _fill(shape, 1.0, dtype) + if _is_shape_empty(shape): + return full(shape, 1.0, dtype) output = F.fill(dtype, shape, 1) return output @@ -290,14 +285,15 @@ def zeros(shape, dtype=mstype.float32): Args: shape (Union[int, tuple, list]): the shape of the new tensor. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype. + Default is :class:`mstype.float32`. Returns: - Tensor, with the designated shape and dtype, filled with zeros. + Tensor, with the designated `shape` and `dtype`, filled with zeros. Raises: TypeError: If input arguments have types not specified above. + ValueError: If `shape` entries have values :math:`< 0`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -310,30 +306,31 @@ def zeros(shape, dtype=mstype.float32): """ shape = _check_shape(shape) dtype = _check_dtype(dtype) - if _check_shape_contain_zero(shape): - return _fill(shape, 0.0, dtype) + if _is_shape_empty(shape): + return full(shape, 0.0, dtype) output = F.fill(dtype, shape, 0) return output def full(shape, fill_value, dtype=None): """ - Returns a new tensor of given shape and type, filled with fill_value. + Returns a new tensor of given shape and type, filled with `fill_value`. Args: shape (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., - (2, 3) or 2. - fill_value (Union[int, float, bool, list, tuple]): scalar or array_like + :math:`(2, 3)` or :math:`2`. + fill_value (Union[int, float, bool, list, tuple]): Scalar or array_like fill value. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`, if dtype is None, the data type - of the new tensor will be inferred from fill_value. Default is None. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, + if `dtype` is :class:`None`, the data type of the new tensor will be inferred from + `fill_value`. Default is :class:`None`. Returns: Tensor, with the designated shape and dtype, filled with `fill_value`. Raises: TypeError: If input arguments have types not specified above. + ValueError: If `shape` has entries < 0. 
Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -344,53 +341,59 @@ def full(shape, fill_value, dtype=None): [[True True] [True True]] """ - if dtype is None: - dtype = array(fill_value).dtype - shape = _check_shape(shape) - _ = _check_input_for_asarray(fill_value) - dtype = _check_dtype(dtype) - - if isinstance(fill_value, (int, float, bool)) and not _check_shape_contain_zero(shape): - return F.fill(dtype, shape, fill_value) - - # if fill_value is array_like or shape contains zero. fall back to original - # numpy creation - return Tensor(onp.full(shape, fill_value, mstype.dtype_to_nptype(dtype))) - - -def arange(*args, **kwargs): + if not isinstance(fill_value, ARRAY_TYPES): + _raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value) + if dtype is not None: + dtype = _check_dtype(dtype) + else: + if isinstance(fill_value, (int, float, bool)): + dtype = _get_dtype_from_scalar(fill_value) + if isinstance(fill_value, Tensor): + dtype = fill_value.dtype + + if not _is_shape_empty(shape): + if isinstance(fill_value, (int, float, bool)): + return F.fill(dtype, shape, fill_value) + if isinstance(fill_value, (list, tuple)): + fill_value = asarray_const(fill_value) + if isinstance(fill_value, Tensor): + fill_value = _expand(fill_value, len(shape)) + return F.tile(fill_value, _tile_size(fill_value.shape, shape, len(shape))) + # if shape contains zero, use c.Tensor() + return _convert_64_to_32(empty_compile(dtype, shape)) + + +def arange(start, stop=None, step=None, dtype=None): """ Returns evenly spaced values within a given interval. - Returns `num` evenly spaced samples, calculated over the interval [`start`, `stop`]. + Returns `num` evenly spaced samples, calculated over the interval :math:`[start, stop]`. The endpoint of the interval can optionally be excluded. - The current implementation is a direct wrapper on top of numpy.arange, except that - the default dtype is float32 and int32, compare to float64 and int64 for numpy - implementation. Args: start(Union[int, float]): Start of interval. The interval includes this value. - When stop is provided as a position argument, start must be given, when stop - is a normal argument, start can be optional, and default is 0. + When `stop` is provided as a position argument, `start` must be given, when `stop` + is a normal argument, `start` can be optional, and default is 0. Please see additional examples below. stop(Union[int, float], optional): End of interval. The interval does not - include this value, except in some cases where step is not an integer + include this value, except in some cases where `step` is not an integer and floating point round-off affects the length of out. step(Union[int, float], optional): Spacing between values. For any output - out, this is the distance between two adjacent values, out[i+1] - out[i]. - The default step size is 1. If step is specified as a position argument, - start must also be given. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. If dtype is None, the data type - of the new tensor will be inferred from start, stop and step. Default is None. + `out`, this is the distance between two adjacent values, :math:`out[i+1] - out[i]`. + The default step size is 1. If `step` is specified as a position argument, + `start` must also be given. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype. 
+ If dtype is None, the data type of the new tensor will be inferred from start, + stop and step. Default is None. Returns: - arangend tensor of evenly spaced values. + Tensor with evenly spaced values. Raises: - TypeError: If input arguments have types not specified above, or arguments are - not given in the correct orders specified above. + TypeError(PyNative Mode) or RuntimeError(Graph Mode): If input arguments + have types not specified above, or arguments are not given in the correct + orders specified above. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -407,68 +410,75 @@ def arange(*args, **kwargs): [0. 0.5 1. 1.5 2. 2.5] >>> print(np.arange(stop=3)) # This will lead to TypeError """ - # infer the dtype, if either of start, end, step is float, default dtype is - # float32, else int32. - int_flag = True - final_dtype = None - - if args: - for item in args: - if isinstance(item, float): - int_flag = False - if kwargs: - if ('start' in kwargs and isinstance(kwargs['start'], float)) or \ - ('stop' in kwargs and isinstance(kwargs['stop'], float)) or \ - ('step' in kwargs and isinstance(kwargs['step'], float)): - int_flag = False - - if int_flag: - final_dtype = onp.int32 + # This implementation was inspired by jax.numpy.arange + # infer the dtype + if dtype is None: + dtype = _get_dtype_from_scalar(start, stop, step) + if stop is None and step is None: # (start, stop, step) -> (0, start, 1) + num = _ceil(start) + out = _iota(mstype.float32, num) + elif step is None: # (start, stop, step) -> (start, stop, 1) + num = _ceil(stop - start) + out = _iota(mstype.float32, num) + start + elif stop is None: # (start, stop, step) -> (0, start, step) + num = _ceil(start / step) + out = _iota(mstype.float32, num) * step else: - final_dtype = onp.float32 - - if 'dtype' in kwargs and kwargs['dtype'] is not None: - final_dtype = _check_dtype(kwargs['dtype']) - final_dtype = mstype.dtype_to_nptype(final_dtype) - kwargs['dtype'] = final_dtype - out = onp.arange(*args, **kwargs) - out = Tensor.from_numpy(out) - return out + num = _ceil((stop - start) / step) + out = _iota(mstype.float32, num) * step + start + return out.astype(dtype) + + +def _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis): + """utility parameter checking function for linspace, logspace, geomspace.""" + if not isinstance(start, ARRAY_TYPES): + _raise_type_error("start should be int, float, bool, list, tuple, Tensor, but got", start) + if not isinstance(stop, ARRAY_TYPES): + _raise_type_error("end should be int, float, bool, list, tuple, Tensor, but got", stop) + if not isinstance(start, Tensor): + start = _type_convert(Tensor, start).astype(mstype.float32) + if not isinstance(stop, Tensor): + stop = _type_convert(Tensor, stop).astype(mstype.float32) + if not isinstance(num, int): + _raise_type_error("num should be an integer, but got ", num) + if not isinstance(endpoint, bool): + _raise_type_error("endpoint should be an boolean, but got ", endpoint) + if dtype is not None: + dtype = _check_dtype(dtype) + else: + dtype = mstype.float32 + axis = _canonicalize_axis(axis, start.ndim+1) + return start, stop, num, endpoint, dtype, axis def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0): """ Returns evenly spaced values within a given interval. - The current implementation is a direct wrapper on top of numpy.linspace, except - the default dtype is float32, compare to float64 for numpy, - Args: - start (Union[int, list(int), tuple(int),tensor]):The starting value of the sequence. 
- stop (Union[int, list(int), tuple(int),tensor]):The end value of the sequence, + start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence. + stop (Union[int, list(int), tuple(int), tensor]): The end value of the sequence, unless `endpoint` is set to False. In that case, the sequence consists - of all but the last of ``num + 1` evenly spaced samples, so that `stop` + of all but the last of `num + 1` evenly spaced samples, so that `stop` is excluded. Note that the step size changes when `endpoint` is False. num (int, optional): Number of samples to generate. Default is 50. endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is not included. Default is True. retstep (bool, optional): If True, return (`samples`, `step`), where `step` is the spacing between samples. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`.If `dtype` is None, infer the data - type from other input arguments. Default is None. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, + If `dtype` is None, infer the data type from other input arguments. Default is None. axis (int, optional): The axis in the result to store the samples. Relevant - only if start or stop are array-like. By default (0), the samples will - be along a new axis inserted at the beginning. Use -1 to get an axis at the end. - Default is 0. + only if start or stop are array-like. By default :math:`(0)`, the samples will + be along a new axis inserted at the beginning. Use :math:`-1` to get an axis at the end. + Default is :math:`0`. Returns: - samples (Tensor): There are `num` equally spaced samples in the closed interval - ``[start, stop]`` or the half-open interval ``[start, stop)`` - (depending on whether `endpoint` is True or False). + Tensor, with `num` equally spaced samples in the closed interval + :math:`[start, stop]` or the half-open interval :math:`[start, stop)` + (depending on whether `endpoint` is True or False). - step (float, optional): Only returned if `retstep` is True. - Size of spacing between samples. + Step, the size of spacing between samples, only returned if `retstep` is True. Raises: TypeError: If input arguments have types not specified above. @@ -481,33 +491,37 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis >>> print(np.linspace(0, 5, 6)) [0. 1. 2. 3. 4. 5.] 
""" - - if isinstance(start, Tensor): - start = start.asnumpy() - - if isinstance(stop, Tensor): - stop = stop.asnumpy() - - if not isinstance(num, int): - raise TypeError(f"num should be an integer, but got {type(num)}") - - final_dtype = None - if dtype is not None: - final_dtype = _check_dtype(dtype) - final_dtype = mstype.dtype_to_nptype(final_dtype) - else: - final_dtype = onp.float32 - - dtype = final_dtype - out = onp.linspace(start, stop, num, endpoint, retstep, dtype, axis) - + # This implementation was inspired by jax.numpy.linspace and numpy.linspace + start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis) + if not isinstance(retstep, bool): + _raise_type_error("retstep should be an boolean, but got ", retstep) + start, stop = broadcast_arrays(start, stop) + axis = _canonicalize_axis(axis, start.ndim+1) + bounds_shape = start.shape + bounds_shape = bounds_shape[:axis] + (1,) + bounds_shape[axis:] + iota_shape = _list_comprehensions(start.ndim+1, 1, True) + iota_shape = iota_shape[:axis] + (num,) + iota_shape[axis+1:] + num_tensor = _type_convert(Tensor, num).astype(mstype.float32) + div = (num_tensor - 1) if endpoint else num_tensor + + if num > 1: + delta = (stop - start) / div + # This is similar to how numpy and jax compute linspace + start_expand = reshape(start, bounds_shape) + incremental_expand = reshape(_iota(mstype.float32, num), iota_shape) + delta_expand = reshape(delta, bounds_shape) + start_expand, incremental_expand, delta_expand = broadcast_arrays( + start_expand, incremental_expand, delta_expand) + out = start_expand + (incremental_expand * delta_expand) + elif num == 1: + delta = nan if endpoint else stop - start + out = reshape(start, bounds_shape) + else: # num == 0 + delta = nan + out = _type_convert([], Tensor).astype(dtype) if retstep: - array_out, step_out = out[0], out[1] - tensor_out = Tensor(array_out) - return tensor_out, step_out - - tensor_out = Tensor.from_numpy(out) - return tensor_out + return out.astype(dtype), delta + return out.astype(dtype) def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): @@ -516,31 +530,28 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): In linear space, the sequence starts at base ** start (base to the power of start) and ends with base ** stop (see endpoint below). - The current implementation is a direct wrapper on top of numpy.logspace, except - the default dtype is float32, compare to float64 for numpy, Args: - start (Union[int, list(int), tuple(int), tensor]):The starting value of the sequence. - stop (Union[int, list(int), tuple(int), tensor]):The end value of the sequence, + start (Union[int, list(int), tuple(int), tensor]): The starting value of the sequence. + stop (Union[int, list(int), tuple(int), tensor]): The end value of the sequence, unless `endpoint` is set to False. In that case, the sequence consists - of all but the last of ``num + 1` evenly spaced samples, so that `stop` + of all but the last of `num + 1` evenly spaced samples, so that `stop` is excluded. Note that the step size changes when `endpoint` is False. num (int, optional): Number of samples to generate. Default is 50. endpoint (bool, optional): If True, `stop` is the last sample. Otherwise, it is not included. Default is True. base (Union[int, float], optional): The base of the log space. The step size - between the elements in ln(samples) / ln(base) (or log_base(samples)) - is uniform. Default is 10.0. 
- dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`.If `dtype` is None, infer the data - type from other input arguments. Default is None. + between the elements in :math:`ln(samples) / ln(base)` (or :math:`log_{base}(samples)`) + is uniform. Default is :math:`10.0`. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype. + If `dtype` is None, infer the data type from other input arguments. Default is None. axis (int, optional): The axis in the result to store the samples. Relevant - only if start or stop is array-like. By default (0), the samples will - be along a new axis inserted at the beginning. Use -1 to get an axis at the end. - Default is 0. + only if start or stop is array-like. By default (:math:`0`), the samples will + be along a new axis inserted at the beginning. Use :math:`-1` to get an axis at the end. + Default is :math:`0`. Returns: - samples (Tensor): num samples, equally spaced on a log scale. + Tensor, equally spaced on a log scale. Raises: TypeError: If input arguments have types not specified above. @@ -553,25 +564,12 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): >>> print(np.logspace(0, 5, 6, base=2.0)) [ 1. 2. 4. 8. 16. 32.] """ - - if isinstance(start, Tensor): - start = start.asnumpy() - - if isinstance(stop, Tensor): - stop = stop.asnumpy() - - final_dtype = None - if dtype is not None: - final_dtype = _check_dtype(dtype) - final_dtype = mstype.dtype_to_nptype(final_dtype) - else: - final_dtype = onp.float32 - - dtype = final_dtype - out = onp.logspace(start, stop, num, endpoint, base, dtype, axis) - - tensor_out = Tensor.from_numpy(out) - return tensor_out + # This implementation was inspired by jax.numpy.linspace and numpy.linspace + start, stop, num, endpoint, dtype, axis = _type_checking_for_xspace(start, stop, num, endpoint, dtype, axis) + if not isinstance(base, (int, float, bool)): + _raise_type_error("base should be a number, but got ", base) + linspace_res = linspace(start, stop, num, endpoint=endpoint, retstep=False, dtype=None, axis=axis) + return F.tensor_pow(base, linspace_res).astype(dtype) def eye(N, M=None, k=0, dtype=mstype.float32): @@ -580,17 +578,17 @@ def eye(N, M=None, k=0, dtype=mstype.float32): Args: N (int): Number of rows in the output, must be larger than 0. - M (int, optional): Number of columns in the output. If None, defaults to N, - if defined, must be larger than 0. Deault is None. + M (int, optional): Number of columns in the output. If is :class:`None`, defaults to `N`, + if defined, must be larger than 0. Deault is :class:`None`. k (int, optional): Index of the diagonal: 0 (the default) refers to the main diagonal, a positive value refers to an upper diagonal, and a negative value to a lower diagonal. Default is 0. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype. + Default is mstype.float32. Returns: - result (Tensor): A tensor of shape (N,M). A tensor where all elements - are equal to zero, except for the k-th diagonal, whose values are equal to one. + A tensor of shape (N, M). A tensor where all elements are equal to zero, + except for the k-th diagonal, whose values are equal to one. Raises: TypeError: If input arguments have types not specified above. 
@@ -608,14 +606,27 @@ def eye(N, M=None, k=0, dtype=mstype.float32): if M is None: M = N if not (isinstance(M, int) and isinstance(N, int) and isinstance(k, int)): - raise TypeError("Input tensor dimensions should be integers.") + _raise_type_error("Input tensor dimensions should be integers.") out = None - if k != 0 or N == 0 or M == 0: - # Fall back to original numpy creation method - out = onp.eye(N, M, k) - else: - out = F.eye(N, M, dtype) - return asarray(out, dtype=dtype) + if N == 0 or M == 0: + # Fill the shape with any value is fine. + return full((N, M), 0, dtype) + + out = F.eye(N, M, dtype) + + if k >= M or k <= -N: + return full((N, M), 0, dtype) + if k != 0: + out = out.astype(mstype.float32) + if k > 0: + out_left = full((N, k), 0, dtype) + out_right = out[..., 0:M-k:1] + return concatenate((out_left, out_right), 1).astype(dtype) + if k < 0: + out_upper = full((-k, M), 0, dtype) + out_lower = out[0:N+k:1, ...] + return concatenate((out_upper, out_lower), 0).astype(dtype) + return out def identity(n, dtype=mstype.float32): @@ -624,12 +635,12 @@ def identity(n, dtype=mstype.float32): Args: n (int): Number of rows and columns in the output, must be larger than 0. - dtype (Union[mstype.dtype, str], optional): Designated tensor dtype, can - be in format of np.float32, or `float32`. Default is mstype.float32. + dtype (Union[:class:`mindspore.dtype`, str], optional): Designated tensor dtype, + default is :class:`mstype.float32`. Returns: - result (Tensor): A tensor of shape (n,n). A tensor where all elements - are equal to zero, except for the diagonal, whose values are equal to one. + A tensor of shape `(n, n)`, where all elements are equal to zero, + except for the diagonal, whose values are equal to one. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -643,24 +654,31 @@ def identity(n, dtype=mstype.float32): [[1. 0.] [0. 1.]] """ + if not isinstance(n, int): + _raise_type_error("Input tensor dimensions should be integers.") dtype = _check_dtype(dtype) return eye(n, dtype=dtype) +@constexpr +def empty_compile(dtype, shape): + return Tensor_(dtype, shape) + + def empty(shape, dtype=mstype.float32): """ Returns a new array of given shape and type, without initializing entries. Note: - Numpy argument order is not supported. + Numpy argument `order` is not supported. Object arrays are not supported. Args: - shape (int or tuple of int): Shape of the empty array, e.g., + shape (Union[int, tuple(int)]): Shape of the empty array, e.g., (2, 3) or 2. - dtype (data-type): optional. Desired output data-type for the - array, e.g, numpy.int8. Default is numpy.float32. + dtype (:class:`mindspore.dtype`, optional): Desired output data-type for the + array, e.g, mstype.int8. Default is mstype.float32. Returns: Tensor, array of uninitialized (arbitrary) data of the given @@ -669,7 +687,6 @@ def empty(shape, dtype=mstype.float32): Raises: TypeError: if the input shape or dtype is invalid. 
- Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -682,39 +699,21 @@ def empty(shape, dtype=mstype.float32): """ shape = _check_shape(shape) dtype = _check_dtype(dtype) - return Tensor_(dtype, shape) - - -def _shape_matched(fn, arr): - """Returns the matched shape of elements in arr""" - shapes_all = groupby(map(fn, arr)) - shape = next(shapes_all)[0] - if next(shapes_all, False): - return _raise_value_error('Input array must have the same size across a dimension.') - return shape + return empty_compile(dtype, shape) def _get_shape(array_like): - """Returns the shape of the array like object by recursion.""" + """Returns the shape of the array like object.""" if isinstance(array_like, Tensor): - return F.shape(array_like) - if isinstance(array_like, onp.ndarray): return array_like.shape - if isinstance(array_like, (list, tuple)): - shape = _shape_matched(_get_shape, array_like) - return (len(array_like),) + shape - return () + return asarray_const(array_like).shape def _get_dtype(array_like): """Returns the data type of the array like object.""" if isinstance(array_like, Tensor): - return F.dtype(array_like) - if isinstance(array_like, onp.ndarray): - return mstype.pytype_to_dtype(array_like.dtype) - if isinstance(array_like, (list, tuple)): - return asarray(array_like).dtype - return mstype.float32 + return array_like.dtype + return asarray_const(array_like).dtype def _x_like(prototype, dtype, shape, constructor, fill_value=None): @@ -722,12 +721,13 @@ def _x_like(prototype, dtype, shape, constructor, fill_value=None): Returns a tensor with the same shape and type as prototype, using constructor. """ - _ = _check_input_for_asarray(prototype) + if not isinstance(prototype, ARRAY_TYPES): + _raise_type_error("prototype should be int, float, bool, list, tuple, Tensor, but got", prototype) dtype_out = dtype shape_out = shape - if not dtype_out: + if dtype_out is None: dtype_out = _get_dtype(prototype) - if not shape_out and shape_out != 0: + if shape_out is None or isinstance(shape_out, (list, tuple)) and not shape_out: shape_out = _get_shape(prototype) if fill_value is not None: return constructor(shape_out, fill_value, dtype_out) @@ -739,26 +739,23 @@ def empty_like(prototype, dtype=None, shape=None): Returns a new array with the same shape and type as a given array. Note: - Since list or tuple arrays are not supported, input array - must have the same size across a dimension. - If prototype is not a Tensor or a numpy array, dtype is - float32 by default if not provided. + Input array must have the same size across a dimension. + If `prototype` is not a Tensor, dtype is float32 by default if not provided. Args: - prototype (array_like): The shape and data-type of prototype + prototype (Union[Tensor, list, tuple]): The shape and data-type of `prototype` define these same attributes of the returned array. - dtype (data-type): optional. Overrides the data type of the + dtype (:class:`mindspore.dtype`, optional): Overrides the data type of the result. - shape (int or sequence of ints): optional. Overrides the shape + shape (int or sequence of ints, optional): Overrides the shape of the result. Returns: Tensor, array of uninitialized (arbitrary) data with the same - shape and type as prototype. + shape and type as `prototype`. Raises: - ValueError: if prototype does not have the same shape across each - dimension. + ValueError: if `prototype` is not a Tensor, list or tuple. 
Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -779,25 +776,22 @@ def ones_like(a, dtype=None, shape=None): Returns an array of ones with the same shape and type as a given array. Note: - Since list or tuple arrays are not supported, input array - must have the same size across a dimension. - If a is not a Tensor or a numpy array, dtype is float32 by default - if not provided. + Input array must have the same size across a dimension. + If `a` is not a Tensor, dtype is float32 by default if not provided. Args: - a (array_like): The shape and data-type of a define these same + a (Union[Tensor, list, tuple]): The shape and data-type of a define these same attributes of the returned array. - dtype (data-type): optional. Overrides the data type of the + dtype (:class:`mindspore.dtype`, optional): Overrides the data type of the result. - shape (int or sequence of ints): optional. Overrides the shape - of the result. + shape (int or sequence of ints, optional): Overrides the shape + of the result. Returns: - Tensor, array of ones with the same shape and type as a. + Tensor, array of ones with the same shape and type as `a`. Raises: - ValueError: if prototype does not have the same shape across each - dimension. + ValueError: if `a` is not a Tensor, list or tuple. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -808,11 +802,8 @@ def ones_like(a, dtype=None, shape=None): >>> output = np.ones_like(a) >>> print(output) [[[1. 1.]] - [[1. 1.]] - [[1. 1.]] - [[1. 1.]]] """ return _x_like(a, dtype, shape, ones) @@ -823,25 +814,22 @@ def zeros_like(a, dtype=None, shape=None): Returns an array of zeros with the same shape and type as a given array. Note: - Since list or tuple arrays are not supported, input array - must have the same size across a dimension. - If a is not a Tensor or a numpy array, dtype is float32 by default - if not provided. + Input array must have the same size across a dimension. + If `a` is not a Tensor, dtype is float32 by default if not provided. Args: - a (array_like): The shape and data-type of a define these same + a (Union[Tensor, list, tuple]): The shape and data-type of a define these same attributes of the returned array. - dtype (data-type): optional. Overrides the data type of the + dtype (:class:`mindspore.dtype`, optional): Overrides the data type of the result. - shape (int or sequence of ints): optional. Overrides the shape + shape (int or sequence of ints, optional): Overrides the shape of the result. Returns: - Tensor, array of zeros with the same shape and type as a. + Tensor, array of zeros with the same shape and type as `a`. Raises: - ValueError: if prototype does not have the same shape across each - dimension. + ValueError: if `a` is not a Tensor, list or tuple. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -852,11 +840,8 @@ def zeros_like(a, dtype=None, shape=None): >>> output = np.zeros_like(a) >>> print(output) [[[0. 0.]] - [[0. 0.]] - [[0. 0.]] - [[0. 0.]]] """ return _x_like(a, dtype, shape, zeros) @@ -867,26 +852,23 @@ def full_like(a, fill_value, dtype=None, shape=None): Returns a full array with the same shape and type as a given array. Note: - Since list or tuple arrays are not supported, input array - must have the same size across a dimension. - If a is not a Tensor or a numpy array, dtype is float32 by default - if not provided. + Input array must have the same size across a dimension. + If `a` is not a Tensor, dtype is float32 by default if not provided. 
Args: - a (array_like): The shape and data-type of a define these same + a (Union[Tensor, list, tuple]): The shape and data-type of `a` define these same attributes of the returned array. fill_value (scalar): Fill value. - dtype (data-type): optional. Overrides the data type of the + dtype (:class:`mindspore.dtype`, optional): Overrides the data type of the result. - shape (int or sequence of ints): optional. Overrides the shape + shape (int or sequence of ints, optional): Overrides the shape of the result. Returns: - Tensor, array of fill_value with the same shape and type as a. + Tensor, array of fill_value with the same shape and type as `a`. Raises: - ValueError: if prototype does not have the same shape across each - dimension. + ValueError: if `a` is not a Tensor, list or tuple. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -897,11 +879,8 @@ def full_like(a, fill_value, dtype=None, shape=None): >>> output = np.full_like(a, 0.5) >>> print(output) [[[0.5 0.5]] - [[0.5 0.5]] - [[0.5 0.5]] - [[0.5 0.5]]] """ return _x_like(a, dtype, shape, full, fill_value=fill_value) @@ -909,22 +888,22 @@ def full_like(a, fill_value, dtype=None, shape=None): def tri(N, M=None, k=0, dtype=mstype.float32): """ - Returns an array with ones at and below the given diagonal and zeros elsewhere. + Returns a tensor with ones at and below the given diagonal and zeros elsewhere. Args: N(int): Number of rows in the array. - M(int, optional): Number of columns in the array. By default, M is taken + M(int, optional): Number of columns in the array. By default, `M` is taken equal to N. k(int, optional): The sub-diagonal at and below which the array is filled. - k = 0 is the main diagonal, while k < 0 is below it, and k > 0 is above. + :math:`k = 0` is the main diagonal, while :math:`k < 0` is below it, and :math:`k > 0` is above. The default is 0. - dtype(mstype.dtype, optional): Data type of the returned array. The default - is mstype.float32. + dtype(:class:`mindspore.dtype`, optional): Data type of the returned array. The default + is :class:`mindspore.dtype`. Returns: - tri(Tensor): Tensor with shape (N, M), with its lower triangle filled with - ones and zeros elsewhere; in other words T[i,j] == 1 for j <= i + k, - 0 otherwise. + Tensor with shape `(N, M)`, with its lower triangle filled with + ones and zeros elsewhere; in other words :math:`T[i,j] = 1` for :math:`j <= i + k`, + :math:`0` otherwise. Raises: TypeError: If input arguments have types not specified above. @@ -933,12 +912,12 @@ def tri(N, M=None, k=0, dtype=mstype.float32): ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> import mindspore.numpy as np - >>> output = np.tri(3, 3, 1) - >>> print(output) - [[1. 1. 0.] - [1. 1. 1.] - [1. 1. 1.]] + >>> import mindspore.numpy as np + >>> output = np.tri(3, 3, 1) + >>> print(output) + [[1. 1. 0.] + [1. 1. 1.] + [1. 1. 1.]] """ if M is None: M = N @@ -947,77 +926,77 @@ def tri(N, M=None, k=0, dtype=mstype.float32): def tril(m, k=0): """ - Returns a lower triangle of an array. + Returns a lower triangle of a tensor. - Returns a copy of an array with elements above the k-th diagonal zeroed. + Returns a copy of a tensor with elements above the `k-th` diagonal zeroed. Args: - m(array_like): The shape and data-type of m define these same - attributes of the returned array. - k(int, optional): Diagonal above which to zero elements. k = 0 (the default) - is the main diagonal, k < 0 is below it and k > 0 is above. 
+ m (Union[Tensor, list, tuple]): The shape and data-type of `m` define these same + attributes of the returned tensor. + k (int, optional): Diagonal above which to zero elements. :math:`k = 0` (the default) + is the main diagonal, :math:`k < 0` is below it and :math:`k > 0` is above. Returns: - tril(Tensor): Lower triangle of m, of same shape and data-type as m. + Lower triangle of `m`, of same shape and data-type as `m`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Raises: TypeError: If input arguments have types not specified above. - ValueError: If input m's rank < 1. + ValueError: If input `m`\'s rank :math:`< 1`. Examples: - >>> import mindspore.numpy as np - >>> output = np.tril(np.ones((3, 3))) - >>> print(output) - [[1. 0. 0.] - [1. 1. 0.] - [1. 1. 1.]] + >>> import mindspore.numpy as np + >>> output = np.tril(np.ones((3, 3))) + >>> print(output) + [[1. 0. 0.] + [1. 1. 0.] + [1. 1. 1.]] """ - m = asarray(m) - shape = _get_shape(m) - dtype = _get_dtype(m) + if not isinstance(m, Tensor): + m = asarray_const(m) + dtype = m.dtype m = m.astype(mstype.float32) - assist = nn_tril(shape, mstype.float32, k) + assist = nn_tril(m.shape, mstype.float32, k) return F.tensor_mul(assist, m).astype(dtype) def triu(m, k=0): """ - Returns an upper triangle of an array. + Returns an upper triangle of a tensor. - Returns a copy of an array with elements below the k-th diagonal zeroed. + Returns a copy of a tensor with elements below the `k-th` diagonal zeroed. Args: - m(array_like): The shape and data-type of m define these same - attributes of the returned array. - k(int, optional): Diagonal below which to zero elements. k = 0 (the default) - is the main diagonal, k < 0 is below it and k > 0 is above. + m (Union[Tensor, list, tuple]): The shape and data-type of `m` define these same + attributes of the returned tensor. + k (int, optional): Diagonal below which to zero elements. :math:`k = 0` (the default) + is the main diagonal, :math:`k < 0` is below it and :math:`k > 0` is above. Returns: - triu(Tensor): Upper triangle of m, of same shape and data-type as m. + Upper triangle of `m`, of same shape and data-type as `m`. Raises: TypeError: If input arguments have types not specified above. - ValueError: If input m's rank < 1. + ValueError: If input `m`\'s rank < 1. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> import mindspore.numpy as np - >>> output = np.triu(np.ones((3, 3))) - >>> print(output) - [[1. 1. 1.] - [0. 1. 1.] - [0. 0. 1.]] + >>> import mindspore.numpy as np + >>> output = np.triu(np.ones((3, 3))) + >>> print(output) + [[1. 1. 1.] + [0. 1. 1.] + [0. 0. 1.]] """ - m = asarray(m) - shape = _get_shape(m) - dtype = _get_dtype(m) + if not isinstance(m, Tensor): + m = asarray_const(m) + dtype = m.dtype m = m.astype(mstype.float32) - assist = nn_triu(shape, mstype.float32, k) + assist = nn_triu(m.shape, mstype.float32, k) return F.tensor_mul(assist, m).astype(dtype) @@ -1025,27 +1004,27 @@ def diagonal(a, offset=0, axis1=0, axis2=1): """ Returns specified diagonals. - If `a` is 2-D, returns the diagonal of a with the given offset, i.e., the - collection of elements of the form a[i, i+offset]. If `a` has more than two - dimensions, then the axes specified by axis1 and axis2 are used to determine + If `a` is 2-D, returns the diagonal of `a` with the given offset, i.e., the + collection of elements of the form ``a[i, i+offset]``. 
If `a` has more than two + dimensions, then the axes specified by `axis1` and `axis2` are used to determine the 2-D sub-array whose diagonal is returned. The shape of the resulting - array can be determined by removing axis1 and axis2 and appending an index + array can be determined by removing `axis1` and `axis2` and appending an index to the right equal to the size of the resulting diagonals. Args: a (Tensor): Array from which the diagonals are taken. - offset (int): optional. Offset of the diagonal from the main diagonal. - Can be positive or negative. Defaults to main diagonal (0). - axis1 (int): optional. Axis to be used as the first axis of the 2-D + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0). - axis2 (int): optional. Axis to be used as the second axis of the 2-D + axis2 (int, optional): Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to - second axis (1). + second axis. Returns: - Tensor, if `a` is 2-D, then a 1-D array containing the diagonal. If - a.ndim > 2, then the dimensions specified by axis1 and axis2 are removed, + Tensor, if `a` is 2-D, then `a` 1-D array containing the diagonal. If + ``a.ndim > 2``, then the dimensions specified by `axis1` and `axis2` are removed, and a new axis inserted at the end corresponding to the diagonal. Raises: @@ -1069,7 +1048,6 @@ def diagonal(a, offset=0, axis1=0, axis2=1): >>> print(a) [[[0 1] [2 3]] - [[4 5] [6 7]]] >>> output = np.diagonal(a, 0, 0, 1) @@ -1082,11 +1060,11 @@ def diagonal(a, offset=0, axis1=0, axis2=1): return _raise_value_error('diagonal requires an array of at least two dimensions') dtype = F.dtype(a) - if _is_empty(F.shape(a)): + if _is_shape_empty(F.shape(a)): return _empty(dtype, (0,)) cast_type = dtype - if not isinstance(dtype, Float): + if not _check_is_float(dtype): # reduce_sum only supports float types cast_type = mstype.float32 a = F.cast(a, cast_type) @@ -1101,9 +1079,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1): shape = F.shape(a) n, m = shape[-2:] - e = _eye(n, m, offset, cast_type) - e = _expand(e, ndim) - e = _broadcast_to(e, F.shape(e), F.shape(a), ndim) + e = eye(n, m, offset, cast_type) + e = _broadcast_to_shape(e, F.shape(a)) prod = F.tensor_mul(a, e) res = F.reduce_sum(prod, -1) @@ -1125,39 +1102,37 @@ def diagonal(a, offset=0, axis1=0, axis2=1): return res -@constexpr -def _eye(N, M, k, dtype): - return eye(N=N, M=M, k=k, dtype=dtype) - - -def trace(a, offset=0, axis1=0, axis2=1): +def trace(a, offset=0, axis1=0, axis2=1, dtype=None): """ Returns the sum along diagonals of the array. If `a` is 2-D, the sum along its diagonal with the given offset is returned, - i.e., the sum of elements a[i,i+offset] for all i. - If `a` has more than two dimensions, then the axes specified by axis1 and - axis2 are used to determine the 2-D sub-arrays whose traces are returned. - The shape of the resulting array is the same as that of a with axis1 and - axis2 removed. + i.e., the sum of elements ``a[i,i+offset]`` for all `i`. + If `a` has more than two dimensions, then the axes specified by `axis1` and + `axis2` are used to determine the 2-D sub-arrays whose traces are returned. + The shape of the resulting array is the same as that of a with `axis1` and + `axis2` removed. 
Note: - Numpy arguments dtype and out are not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. Args: a (Tensor): Array from which the diagonals are taken. - offset (int): optional. Offset of the diagonal from the main diagonal. - Can be positive or negative. Defaults to main diagonal (0). - axis1 (int): optional. Axis to be used as the first axis of the 2-D + offset (int, optional): Offset of the diagonal from the main diagonal. + Can be positive or negative. Defaults to main diagonal. + axis1 (int, optional): Axis to be used as the first axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to first axis (0). - axis2 (int): optional. Axis to be used as the second axis of the 2-D + axis2 (int, optional): Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to - second axis (1). + second axis. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. Returns: - Tensor, sum_along_diagonals. If a is 2-D, the sum along the diagonal - is returned. If a has larger dimensions, then an array of sums along + Tensor, sum_along_diagonals. If `a` is 2-D, the sum along the diagonal + is returned. If `a` has larger dimensions, then an array of sums along diagonals is returned. Raises: @@ -1181,12 +1156,13 @@ def trace(a, offset=0, axis1=0, axis2=1): """ d = diagonal(a, offset, axis1=axis1, axis2=axis2) shape = F.shape(d) - dtype = F.dtype(d) + if dtype is None: + dtype = F.dtype(d) if shape[-1] == 0: return _empty(dtype, shape[:-1]) cast_type = dtype - if not isinstance(dtype, Float): + if not _check_is_float(dtype): # reduce sum only supports float types cast_type = mstype.float32 d = F.cast(d, cast_type) @@ -1194,3 +1170,483 @@ def trace(a, offset=0, axis1=0, axis2=1): if not _check_same_type(cast_type, dtype): res = F.cast(res, dtype) return res + + +def cumsum(a, axis=None, dtype=None): + """ + Returns the cumulative sum of the elements along a given axis. + + Args: + a (Tensor): Input tensor. + axis (int, optional): Axis along which the cumulative sum is computed. The + default (None) is to compute the cumsum over the flattened array. + dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as `a`, + unless `a` has an integer dtype with a precision less than that of the + default platform integer. In that case, the default platform integer + is used. + + Returns: + Tensor. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis is out of range. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.cumsum(np.ones((3,3)), axis=0) + >>> print(output) + [[1. 1. 1.] + [2. 2. 2.] + [3. 3. 
3.]] + """ + _check_input_tensor(a) + original_dtype = F.dtype(a) + # If original array is int, and has precision less then int32, convert to int32 + if _check_same_type(original_dtype, mstype.bool_) or \ + _check_same_type(original_dtype, mstype.int8) or \ + _check_same_type(original_dtype, mstype.int16): + a = a.astype(mstype.int32) + if axis is None: + a = a.ravel() + axis = 0 + _check_axis_in_range(axis, a.ndim) + if dtype is not None and not _check_same_type(original_dtype, dtype): + return _cumsum_default(a, axis).astype(dtype, copy=False) + return _cumsum_default(a, axis) + + +def _index(i, size, Cartesian=True): + """If Cartesian=True, index 0 is swapped with index 1.""" + if Cartesian: + if i == 1: + return 0 + if i == 0: + if size >= 2: + return 1 + return i + + +def meshgrid(*xi, sparse=False, indexing='xy'): + """ + Returns coordinate matrices from coordinate vectors. + + Make `N-D` coordinate arrays for vectorized evaluations of `N-D` + scalar/vector fields over `N-D` grids, given one-dimensional + coordinate arrays `x1, x2,…, xn`. + + Note: + Numpy argument copy is not supported, and a copy is always + returned. + + Args: + *xi (Tensor): 1-D arrays representing the coordinates + of a grid. + indexing (‘xy’, ‘ij’, optional): Cartesian (‘xy’, default) or + matrix (‘ij’) indexing of output. In the 2-D case with + inputs of length `M` and `N`, the outputs are of shape `(N, M)` + for ‘xy’ indexing and `(M, N)` for ‘ij’ indexing. In the 3-D + case with inputs of length `M`, `N` and `P`, outputs are of shape + `(N, M, P)` for ‘xy’ indexing and `(M, N, P)` for ‘ij’ indexing. + sparse (bool, optional): If True a sparse grid is returned in + order to conserve memory. Default is False. + + Returns: + Tuple of tensors, for vectors `x1, x2,…, xn` with lengths + ``Ni=len(xi)``, return `(N1, N2, N3,...Nn)` shaped arrays if + ``indexing=’ij’`` or `(N2, N1, N3,...Nn)` shaped arrays if + ``indexing=’xy’`` with the elements of `xi` repeated to fill the matrix + along the first dimension for `x1`, the second for `x2` and so on. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.linspace(0, 1, 3) + >>> y = np.linspace(0, 1, 2) + >>> xv, yv = np.meshgrid(x, y) + >>> print(xv) + [[0. , 0.5, 1. ], + [0. , 0.5, 1. ]] + >>> print(yv) + [[0., 0., 0.], + [1., 1., 1.]] + >>> xv, yv = np.meshgrid(x, y, sparse=True) + >>> print(xv) + [[0. , 0.5, 1. ]] + >>> print(yv) + [[0.], + [1.] + """ + _check_input_tensor(*xi) + + grids = [] + for x in xi: + if F.rank(x) == 1: + grids.append(x) + else: + grids.append(ravel(x)) + ndim = len(grids) + + Cartesian = indexing == 'xy' + shape_out = () + for i in range(len(grids)): + grid_index = _index(i, ndim, Cartesian=Cartesian) + shape_out += (F.shape(grids[grid_index])[0],) + + res = [] + for i, x in enumerate(grids): + grid_index = _index(i, ndim, Cartesian=Cartesian) + shape_expanded = _expanded_shape(ndim, shape_out[grid_index], grid_index) + x = x.reshape(shape_expanded) + if not sparse: + x = F.tile(x, _tile_size(shape_expanded, shape_out, ndim)) + res.append(x) + return res + + +class nd_grid: + """ + Construct a multi-dimensional "meshgrid". + + ``grid = nd_grid()`` creates an instance which will return a mesh-grid + when indexed. + If instantiated with an argument of ``sparse=True``, the mesh-grid is + open (or not fleshed out) so that only one-dimension of each + returned argument is greater than 1. + + Args: + sparse (bool): Whether the grid is sparse or not. 
Default is + False. + + Returns: + Tensor or tuple of tensor, a meshgrid. If ``sparse=False``, returns + tensors are all of the same dimensions; and if ``sparse=True``, + returns tensors with only one dimension not equal to `1`. + """ + def __init__(self, sparse=False): + self.sparse = sparse + + def __getitem__(self, keys): + if isinstance(keys, slice): + keys = (keys,) + + xi = [] + for k in keys: + if not isinstance(k.start, int) or not isinstance(k.stop, int): + _raise_type_error('slice indices must be integers') + if k.step: + step = k.step + else: + step = 1 + if isinstance(step, complex): + v = linspace(k.start, k.stop, int(abs(step.imag))) + else: + v = arange(k.start, k.stop, step) + xi.append(v) + grids = meshgrid(*xi, sparse=self.sparse, indexing='ij') + + if len(grids) == 1: + return grids[0] + if self.sparse: + return grids + + expanded = [] + for grid in grids: + expanded.append(F.expand_dims(grid, 0)) + res = concatenate(tuple(expanded)) + return res + + +class mGridClass(nd_grid): + """ + mgrid is an :class:`nd_grid` instance with ``sparse=False``. + + The dimension and number of the output arrays are equal to the number + of indexing dimensions. If the step length is not a complex number, + then the stop is not inclusive. However, if the step length is a complex + number (e.g. 5j), then the integer part of its magnitude is interpreted + as specifying the number of points to create between the start and + stop values, where the stop value is inclusive. + + Returns: + Tensor or tuple of tensor, a meshgrid. + + Raises: + TypeError: if slicing indices are not integers. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore.numpy import mgrid + >>> output = mgrid[0:5, 0:5] + >>> print(output) + [[[0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [2, 2, 2, 2, 2], + [3, 3, 3, 3, 3], + [4, 4, 4, 4, 4]], + [[0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4], + [0, 1, 2, 3, 4]]] + >>> output = mgrid[-1:1:5j] + >>> print(output) + [-1. , -0.5, 0. , 0.5, 1. ] + """ + def __init__(self): + super(mGridClass, self).__init__(sparse=False) + + +class oGridClass(nd_grid): + """ + ogrid is an :class:`nd_grid` instance with ``sparse=True``. + + The dimension and number of the output arrays are equal to the number + of indexing dimensions. If the step length is not a complex number, + then the stop is not inclusive. However, if the step length is a complex + number (e.g. 5j), then the integer part of its magnitude is interpreted + as specifying the number of points to create between the start and + stop values, where the stop value is inclusive. + + Raises: + TypeError: if slicing indices are not integers. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> from mindspore.numpy import ogrid + >>> output = ogrid[0:5,0:5] + >>> print(output) + [Tensor(shape=[5, 1], dtype=Int32, value= + [[0], + [1], + [2], + [3], + [4]]), Tensor(shape=[1, 5], dtype=Int32, value= + [[0, 1, 2, 3, 4]])] + >>> output = ogrid[-1:1:5j] + >>> print(output) + [-1. , -0.5, 0. , 0.5, 1. ] + """ + def __init__(self): + super(oGridClass, self).__init__(sparse=True) + + +mgrid = mGridClass() + + +ogrid = oGridClass() + + +def diag(v, k=0): + """ + Extracts a diagonal or construct a diagonal array. + + Args: + v (Tensor): If `v` is a 2-D array, return a copy of its `k-th` diagonal. + If `v` is a 1-D array, return a 2-D array with v on the `k-th` diagonal. + k (int, optional): Diagonal in question. The default is 0. 
Use ``k>0`` for + diagonals above the main diagonal, and ``k<0`` for diagonals below the + main diagonal. + + Returns: + Tensor, the extracted diagonal or constructed diagonal array. + + Raises: + ValueError: if input is not 1-D or 2-D. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.arange(9).reshape((3,3)) + >>> print(x) + [[0 1 2] + [3 4 5] + [6 7 8]] + >>> output = np.diag(x) + >>> print(output) + [0 4 8] + >>> output = np.diag(x, k=1) + >>> print(output) + [1 5] + >>> output = np.diag(x, k=-1) + >>> print(output) + [3 7] + """ + ndim = F.rank(v) + if ndim == 1: + return diagflat(v, k=k) + if ndim == 2: + shape = F.shape(v) + dtype = F.dtype(v) + if _is_shape_empty(shape): + return _empty(dtype, (0,)) + e = eye(shape[0], shape[1], k, dtype) + prod = F.tensor_mul(v, e) + + cast_type = dtype + if not isinstance(dtype, Float): + # reduce sum only supports float types + cast_type = mstype.float32 + prod = F.cast(prod, cast_type) + + res = F.reduce_sum(prod, 1) + res = res[_max(0, -k): _min(shape[0], _max(0, shape[1] - k))] + + if not _check_same_type(cast_type, dtype): + res = F.cast(res, dtype) + + return res + return _raise_value_error("Input must be 1- or 2-d.") + + +def diagflat(v, k=0): + """ + Creates a two-dimensional array with the flattened input as a diagonal. + + Note: + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + v (Tensor): Input data, which is flattened and set as the `k-th` diagonal + of the output. + k (int, optional): Diagonal to set; 0, the default, corresponds to the + “main” diagonal, a positive (negative) `k` giving the number of the + diagonal above (below) the main. + + Returns: + Tensor, The 2-D output array. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.diagflat(np.asarray([[1,2], [3,4]])) + >>> print(output) + [[1 0 0 0] + [0 2 0 0] + [0 0 3 0] + [0 0 0 4]] + >>> output = np.diagflat(np.asarray([1,2]), 1) + >>> print(output) + [[0 1 0] + [0 0 2] + [0 0 0]] + """ + _check_input_tensor(v) + dtype = F.dtype(v) + k_abs = _abs(k) + if _is_shape_empty(F.shape(v)): + return zeros((k_abs, k_abs), dtype) + + v = ravel(v) + size = F.shape(v)[0] + e = eye(size, size, 0, dtype) + res = F.tensor_mul(v, e) + + if k != 0: + pad_y = zeros((size, k_abs), dtype) + pad_x = zeros((k_abs, size + k_abs), dtype) + if k < 0: + res = concatenate((res, pad_y), axis=1) + res = concatenate((pad_x, res), axis=0) + else: + res = concatenate((pad_y, res), axis=1) + res = concatenate((res, pad_x), axis=0) + return res + + +def diag_indices(n, ndim=2): + """ + Returns the indices to access the main diagonal of an array. + + This returns a tuple of indices that can be used to access the main + diagonal of an array a with ``a.ndim >= 2`` dimensions and shape `(n, n, …, n)`. + For ``a.ndim = 2`` this is the usual diagonal, for ``a.ndim > 2`` this is the set + of indices to access ``a[i, i, ..., i]`` for ``i = [0..n-1]``. + + Args: + n (int): The size, along each dimension, of the arrays for which + the returned indices can be used. + ndim (int, optional): The number of dimensions. + + Returns: + Tuple of Tensor. + + Raises: + TypeError: if input are not integers. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.diag_indices(5, 3) + >>> print(output) + (Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4]), + Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4]), + Tensor(shape=[5], dtype=Int32, value= [0, 1, 2, 3, 4])) + """ + if not isinstance(n, int) or not isinstance(ndim, int): + _raise_type_error('input must be integers') + return _list_comprehensions(ndim, arange(start=0, stop=n), True) + + +def ix_(*args): + r""" + Constructs an open mesh from multiple sequences. + + This function takes `N` 1-D sequences and returns `N` outputs with `N` + dimensions each, such that the shape is 1 in all but one dimension + and the dimension with the non-unit shape value cycles through all + N dimensions. + Using ix\_ one can quickly construct index arrays that will index + the cross product. ``a[np.ix_([1,3],[2,5])]`` returns the array + ``[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]``. + + Note: + Boolean masks are not supported. + + Args: + *args (Tensor): 1-D, each sequence should be of integer type. + + Returns: + Tuple of Tensor, `N` arrays with `N` dimensions each, with `N` the + number of input sequences. Together these arrays form an open + mesh. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> ixgrid = np.ix_(np.array([0, 1]), np.array([2, 4])) + >>> print(ixgrid) + [Tensor(shape=[2, 1], dtype=Int32, value= + [[0], + [1]]), Tensor(shape=[1, 2], dtype=Int32, value= + [[2, 4]])] + """ + # TODO boolean mask + _check_input_tensor(*args) + ndim = len(args) + res = () + for i, arr in enumerate(args): + if F.rank(arr) != 1: + return _raise_value_error('Cross index must be 1 dimensional') + res += (F.reshape(arr, _expanded_shape(ndim, arr.size, i)),) + return res diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index 939ddbc2b6..53c7e7a72e 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -13,18 +13,24 @@ # limitations under the License. # ============================================================================ """array operations, the function docs are adapted from Numpy API.""" +import operator from ..common import dtype as mstype +from ..common import Tensor from ..ops import operations as P from ..ops import functional as F +from ..ops import composite as C from ..ops.primitive import constexpr from ..nn import Cell -from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ - _is_empty -from .utils_const import _check_is_int, _check_axes_range, _check_start_normalize, \ - _check_is_tensor, _check_is_tuple, _check_is_list, _raise_type_error, _raise_value_error, \ - _infer_out_shape, _empty, _promote, _check_same_type, _check_input_tensor +from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to_shape, \ + _check_input_tensor, _broadcast_to +from .utils_const import _check_axes_range, _check_start_normalize, \ + _raise_type_error, _raise_value_error, _infer_out_shape, _empty, _promote, \ + _check_same_type, _check_axis_valid, _add_unit_axes, _broadcast_tuples, \ + _check_is_float, _check_axis_in_range, _check_axis_type, _canonicalize_axis, \ + _list_comprehensions, _check_element_int, _is_shape_empty, _type_convert, \ + _tuple_getitem, _expanded_shape # According to official numpy reference, the dimension of a numpy array must be less # than 32 @@ -82,11 +88,11 @@ def expand_dims(a, axis): Args: a (Tensor): Input tensor array. 
- axis Union[int, list(int), tuple(int)]: Position in the expanded axes where - the new axis is placed, + axis (Union[int, list(int), tuple(int)]): Position in the expanded axes where + the new axis is placed, Returns: - Tensor, view of a tensor with the number of dimensions increased. + View of `a` with the number of dimensions increased. Raises: TypeError: If input arguments have types not specified above. @@ -102,8 +108,7 @@ def expand_dims(a, axis): >>> print(x.shape) (1, 2, 2) """ - if not _check_is_tensor(F.typeof(a)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(a) shape = F.shape(a) # yield expanded shape based on the axes new_shape = _prepare_shape_for_expand_dims(shape, axis) @@ -116,14 +121,14 @@ def squeeze(a, axis=None): Args: a (Tensor): Input tensor array. - axis: Union[None, int, list(int), tuple(list)]. Default is None. + axis (Union[None, int, list(int), tuple(list)]): Default is None. Returns: - Tensor, with all or a subset of the dimensions of length 1 removed. + Tensor, with all or a subset of the dimensions of length :math:`1` removed. Raises: TypeError: If input arguments have types not specified above. - ValueError: If specified axis has shape entry > 1. + ValueError: If specified axis has shape entry :math:`> 1`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -135,8 +140,7 @@ def squeeze(a, axis=None): >>> print(x.shape) (2, 2) """ - if not _check_is_tensor(F.typeof(a)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(a) return a.squeeze(axis) @@ -146,15 +150,15 @@ def transpose(a, axes=None): Args: a (Tensor): a tensor to be transposed - axes (Union[None, tuple, list]): the axes order, if axes is None, transpose - the entire tensor. Default is None. + axes (Union[None, tuple, list]): the axes order, if `axes` is `None`, transpose + the entire tensor. Default is `None`. Returns: Tensor, the transposed tensor array. Raises: TypeError: If input arguments have types not specified above. - ValueError: If the number of axes is not euqal to a.ndim. + ValueError: If the number of `axes` is not euqal to a.ndim. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -166,8 +170,7 @@ def transpose(a, axes=None): >>> print(x.shape) (3, 2, 1) """ - if not _check_is_tensor(F.typeof(a)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(a) return a.transpose(axes) @@ -179,33 +182,43 @@ def rollaxis(x, axis, start=0): Args: x (Tensor): A Tensor to be transposed. axis (int): The axis to be rolled. - start (int): - - When start >= 0: - - When start <= axis: the axis is rolled back until it lies in - this position (start). - - When start > axis: the axis is rolled until it lies before this - position (start). - - When start < 0: the start will be normalized as follows: - start ........... Normalized start - -(x.ndim+1) raise ValueError - -x.ndim 0 - ... ... - -1 x.ndim-1 - 0 0 - ... ... - x.ndim x.ndim - x.ndim+1 raise ValueError + start (int): Default: 0. + If :math:`start <= axis`, the axis is rolled back until it lies in this position (`start`). + If :math:`start > axis`: the axis is rolled until it lies before this position (`start`). + + If :math:`start < 0`, the start will be normalized as shown in the table. + (Please refer to the source code.) + + .. table + +===========+=================+ + |start |Normalized start | + +===========+=================+ + |-(x.ndim+1)| raise ValueError| + +-----------+-----------------+ + |-x.ndim |0 | + +-----------+-----------------+ + |... |... 
| + +-----------+-----------------+ + |-1 |x.ndim-1 | + +-----------+-----------------+ + |... |... | + +-----------+-----------------+ + |x.ndim |x.ndim | + +-----------+-----------------+ + |x.ndim+1 |raise ValueError | + +===========+=================+ + .. Returns: - Transposed Tensor. Has the same data type as the original tensor x. + Transposed Tensor. Has the same data type as the original tensor `x`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If axis or start is not integer, or x is not tensor. - ValueError: If axis is not in the range from -ndim to ndim-1 or - start is not in the range from -ndim to ndim. + TypeError: If `axis` or `start` is not integer, or `x` is not tensor. + ValueError: If `axis` is not in the range of :math:`[-ndim, ndim-1]` or + `start` is not in the range of :math:`[-ndim, ndim]`. Examples: >>> import mindspore.numpy as np @@ -214,11 +227,10 @@ def rollaxis(x, axis, start=0): >>> print(output.shape) (3, 2, 4) """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Input is not Tensor.") - if not _check_is_int(axis): + _check_input_tensor(x) + if not isinstance(axis, int): _raise_type_error("integer argument expected, but got ", axis) - if not _check_is_int(start): + if not isinstance(start, int): _raise_type_error("integer argument expected, but got ", start) shape = F.shape(x) @@ -257,11 +269,11 @@ def swapaxes(x, axis1, axis2): axis2 (int): Second axis. Returns: - Transposed tensor, has the same data type as the original tensor x. + Transposed tensor, has the same data type as the original tensor `x`. Raises: - TypeError: If axis1 or axis2 is not integer, or x is not tensor. - ValueError: If axis1 or axis2 is not in the range from -ndim to ndim-1. + TypeError: If `axis1` or `axis2` is not integer, or `x` is not tensor. + ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -273,8 +285,7 @@ def swapaxes(x, axis1, axis2): >>> print(output.shape) (4,3,2) """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(x) return x.swapaxes(axis1, axis2) @@ -287,15 +298,15 @@ def reshape(x, new_shape): new_shape (Union[int, list(int), tuple(int)]): The new shape should be compatible with the original shape. If the tuple has only one element, the result will be a 1-D tensor of that length. One shape dimension - can be -1. In this case, the value is inferred from the length of + can be :math:`-1`. In this case, the value is inferred from the length of the tensor and remaining dimensions. Returns: - Reshaped Tensor. Has the same data type as the original tensor x. + Reshaped Tensor. Has the same data type as the original tensor `x`. Raises: - TypeError: If new_shape is not integer, list or tuple, or x is not tensor. - ValueError: If new_shape does not compatible with the original shape. + TypeError: If new_shape is not integer, list or tuple, or `x` is not tensor. + ValueError: If new_shape is not compatible with the original shape. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -317,8 +328,7 @@ def reshape(x, new_shape): >>> print(output) [-0.1 0.3 3.6 0.4 0.5 -3.2] """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(x) return x.reshape(new_shape) @@ -332,10 +342,10 @@ def ravel(x): x (Tensor): A tensor to be flattened. Returns: - Flattened tensor, has the same data type as the original tensor x. 
+ Flattened tensor, has the same data type as the original tensor `x`. Raises: - TypeError: If x is not tensor. + TypeError: If `x` is not tensor. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -347,8 +357,7 @@ def ravel(x): >>> print(output.shape) (24,) """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Input is not Tensor.") + _check_input_tensor(x) return x.ravel() @@ -399,18 +408,17 @@ def concatenate(arrays, axis=0): Joins a sequence of tensors along an existing axis. Args: - arrays: Union[Tensor, tuple(Tensor), list(Tensor)], a tensor or a list - of tensors to be concatenated. - - axis (int, optional): The axis along which the tensors will be joined, - if axis is None, tensors are flattened before use. Default is 0. + arrays (Union[Tensor, tuple(Tensor), list(Tensor)]): a tensor or a list + of tensors to be concatenated. + axis (Union[None, int], optional): The axis along which the tensors will be joined, + if `axis` is :class:`None`, tensors are flattened before use. Default is 0. Returns: - Tensor, a tensor concatenated from a tensor or a list of tensors. + A tensor concatenated from a tensor or a list of tensors. Raises: TypeError: If input arguments have types not specified above. - ValueError: If specified axis < 0, and exceeds tensor.ndim. + ValueError: If specified `axis` < 0, or exceeds tensor.ndim. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -423,9 +431,7 @@ def concatenate(arrays, axis=0): >>> print(x.shape) (1, 2, 4) """ - array_type = F.typeof(arrays) - if _check_is_tensor(array_type): - # if the input is a single tensor + if isinstance(arrays, Tensor): # if only one tensor is provided, it is treated as a tuple along the # first dimension. For example, a tensor of shape (3,4,5) will be treated # as: tuple(tensor_1(4,5), tensor_2(4,5), tensor_3(4,5)) @@ -462,6 +468,47 @@ def concatenate(arrays, axis=0): return P.Concat(axis)(arrays) +def append(arr, values, axis=None): + """ + Appends values to the end of a tensor. + + Args: + arr (Tensor): Values are appended to a copy of this tensor. + values (Tensor): These values are appended to a copy of `arr`. It must be of + the correct shape (the same shape as `arr`, excluding `axis`). If `axis` is + not specified, `values` can be any shape and will be flattened before use. + axis (None, int, optional): The `axis` along which values are appended. If `axis` is not + given, both `arr` and `values` are flattened before use, default is :class:`None`. + + Returns: + Tensor, a copy of tensor with values appended to axis. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If specified axis exceeds `arr.ndim`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.ones((2, 3)) + >>> b = np.ones((2, 1)) + >>> print(np.append(a, b, axis=1).shape) + >>> (2, 4) + """ + _check_input_tensor(arr) + _check_input_tensor(values) + if axis is None: + arr = arr.ravel() + values = values.ravel() + else: + _check_axis_in_range(axis, arr.ndim) + if F.rank(arr) != F.rank(values): + _raise_value_error("all tensors must have same number of dimensions") + return concatenate((arr, values), axis) + + def column_stack(tup): """ Stacks 1-D tensors as columns into a 2-D tensor. 2-D tensors are stacked as-is, @@ -478,8 +525,8 @@ def column_stack(tup): ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If tup is not Tensor, list or tuple. - ValueError: If tup is empty. + TypeError: If `tup` is not Tensor, list or tuple. 
+ ValueError: If `tup` is empty. Examples: >>> import mindspore.numpy as mnp @@ -493,9 +540,9 @@ def column_stack(tup): [2, 5], [3, 6]] """ - if _check_is_tensor(F.typeof(tup)): + if isinstance(tup, Tensor): return tup - if not _check_is_list(tup) and not _check_is_tuple(tup): + if not isinstance(tup, (list, tuple)): _raise_type_error("Tensor or, list or tuple of tensors are required, but got ", tup) trans_tup = () @@ -513,7 +560,7 @@ def column_stack(tup): def vstack(tup): """ Stacks tensors in sequence vertically. - This is equivalent to concatenation along the first axis. 1-D tensors should firstly be reshaped to (1, N), + This is equivalent to concatenation along the first axis. 1-D tensors should firstly be reshaped to `(1, N)`, and then be concatenated along the first axis. Args: @@ -527,8 +574,8 @@ def vstack(tup): ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If tup is not Tensor, list or tuple. - ValueError: If tup is empty. + TypeError: If `tup` is not Tensor, list or tuple. + ValueError: If `tup` is empty. Examples: >>> import mindspore.numpy as mnp @@ -541,9 +588,9 @@ def vstack(tup): [[1, 2, 3], [4, 5, 6]] """ - if _check_is_tensor(F.typeof(tup)): + if isinstance(tup, Tensor): return tup - if not _check_is_list(tup) and not _check_is_tuple(tup): + if not isinstance(tup, (list, tuple)): _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) trans_tup = () @@ -560,7 +607,7 @@ def hstack(tup): """ Stacks tensors in sequence horizontally. This is equivalent to concatenation along the second axis, except for 1-D tensors - where it concatenates along the first axis. + where it concatenates along the first axis. Args: tup (Union[Tensor, tuple, list]): A sequence of 1-D or 2-D tensors. The @@ -574,22 +621,20 @@ def hstack(tup): ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If tup is not Tensor, list or tuple. - ValueError: If tup is empty. + TypeError: If `tup` is not Tensor, list or tuple. + ValueError: If `tup` is empty. Examples: - >>> import mindspore.numpy as mnp - >>> import numpy as onp - >>> from mindspore import Tensor - >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) - >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) - >>> output = mnp.hstack((x1, x2)) + >>> import mindspore.numpy as np + >>> x1 = np.array([1, 2, 3]).astype('float32') + >>> x2 = np.array([4, 5, 6]).astype('float32') + >>> output = np.hstack((x1, x2)) >>> print(output) - [1, 2, 3, 4, 5, 6] + [1. 2. 3. 4. 5. 6.] """ - if _check_is_tensor(F.typeof(tup)): + if isinstance(tup, Tensor): return tup - if not _check_is_list(tup) and not _check_is_tuple(tup): + if not isinstance(tup, (list, tuple)): _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) tuple_of_tensor = () @@ -607,8 +652,9 @@ def hstack(tup): def dstack(tup): """ Stacks tensors in sequence depth wise (along the third axis). - This is equivalent to concatenation along the third axis. 1-D tensors (N,) should be reshaped to (1,N,1). - 2-D tensors (M,N) should be reshaped to (M,N,1) before concatenation. + This is equivalent to concatenation along the third axis. 1-D tensors :math:`(N,)` should be + reshaped to :math:`(1,N,1)`. + 2-D tensors :math:`(M,N)` should be reshaped to :math:`(M,N,1)` before concatenation. Args: tup (Union[Tensor, tuple, list]): A sequence of tensors. The tensors must have the same shape along all but @@ -621,25 +667,23 @@ def dstack(tup): ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If tup is not Tensor, list or tuple. - ValueError: If tup is empty. 
+ TypeError: If `tup` is not Tensor, list or tuple. + ValueError: If `tup` is empty. Examples: - >>> import mindspore.numpy as mnp - >>> import numpy as onp - >>> from mindspore import Tensor - >>> x1 = Tensor(onp.array([1, 2, 3]).astype('int32')) - >>> x2 = Tensor(onp.array([4, 5, 6]).astype('int32')) - >>> output = mnp.dstack((x1, x2)) + >>> import mindspore.numpy as np + >>> x1 = np.array([1, 2, 3]).astype('float32') + >>> x2 = np.array([4, 5, 6]).astype('float32') + >>> output = np.dstack((x1, x2)) >>> print(output) - [[[1, 4], - [2, 5], - [3, 6]]] + [[[1. 4.] + [2. 5.] + [3. 6.]]] """ - if _check_is_tensor(F.typeof(tup)): + if isinstance(tup, Tensor): return tup - if not _check_is_list(tup) and not _check_is_tuple(tup): - _raise_type_error("Tensor or, list or tuple of tensors are required, but got", tup) + if not isinstance(tup, (list, tuple)): + _raise_type_error("Tensor or list or tuple of tensors are required, but got", tup) trans_tup = () for tensor in tup: @@ -655,19 +699,20 @@ def dstack(tup): def where(condition, x=None, y=None): """ - Returns elements chosen from x or y depending on condition. + Returns elements chosen from `x` or `y` depending on `condition`. Note: - As nonzero is not supported, neither x or y can be None. + As nonzero is not supported, neither `x` or `y` can be None. Args: - condition (Tensor): where True, yield x, otherwise yield y. - x, y (Tensor): Values from which to choose. x, y and condition need - to be broadcastable to some shape. + condition (Tensor): where True, yield `x`, otherwise yield `y`. + x (Tensor) + y (Tensor): Values from which to choose. `x`, `y` and `condition` need + to be broadcastable to some shape. Returns: - Tensor or scalar, with elements from x where condition is True, and - elements from y elsewhere. + Tensor or scalar, with elements from `x` where `condition` is True, and + elements from `y` elsewhere. Raises: ValueError: if operands cannot be broadcast. @@ -685,8 +730,7 @@ def where(condition, x=None, y=None): [[[7, 5], [7, 5], [7, 5]], - - [[7, 5], + [[7, 5], [7, 5], [7, 5]]] """ @@ -708,17 +752,13 @@ def where(condition, x=None, y=None): # broadcasts input tensors shape_out = _infer_out_shape(F.shape(condition), F.shape(x), F.shape(y)) - ndim_out = len(shape_out) if not _check_same_type(F.dtype(condition), mstype.float32): # tiling with bool is not supported on GPU condition = F.cast(condition, mstype.float32) - condition = _expand(condition, ndim_out) - x = _expand(x, ndim_out) - y = _expand(y, ndim_out) - condition = _broadcast_to( - condition, F.shape(condition), shape_out, ndim_out) - x = _broadcast_to(x, F.shape(x), shape_out, ndim_out) - y = _broadcast_to(y, F.shape(y), shape_out, ndim_out) + condition = _broadcast_to_shape(condition, shape_out) + x = _broadcast_to_shape(x, shape_out) + y = _broadcast_to_shape(y, shape_out) + if not _check_same_type(F.dtype(condition), mstype.bool_): condition = F.cast(condition, mstype.bool_) res = F.select(condition, x, y) @@ -729,8 +769,7 @@ def where(condition, x=None, y=None): def _atleast_xd(ndim, arys): """Returns arys with at least ndim.""" - for arr in arys: - _check_input_tensor(F.typeof(arr)) + _check_input_tensor(*arys) res = [] for arr in arys: arr = _expand(arr, ndim) @@ -750,11 +789,12 @@ def atleast_1d(*arys): Note: In graph mode, returns a tuple of tensor instead of a list of tensors. + Args: - arys1, arys2, … (Tensor): one or more input tensors. + *arys (Tensor): one or more input tensors. Returns: - Tensor, or list of tensors, each with a.ndim >= 1. 
+ Tensor, or list of tensors, each with ``a.ndim >= 1``. Raises: TypeError: if the input is not a tensor. @@ -763,6 +803,7 @@ def atleast_1d(*arys): ``Ascend`` ``GPU`` ``CPU`` Examples: + >>> import mindspore.numpy as np >>> a = np.ones((2, 3)) >>> b = np.ones(()) >>> c = np.ones(5) @@ -787,10 +828,10 @@ def atleast_2d(*arys): In graph mode, returns a tuple of tensor instead of a list of tensors. Args: - arys1, arys2, … (Tensor): one or more input tensors. + *arys (Tensor): one or more input tensors. Returns: - Tensor, or list of tensors, each with a.ndim >= 2. + Tensor, or list of tensors, each with ``a.ndim >= 2``. Raises: TypeError: if the input is not a tensor. @@ -799,6 +840,7 @@ def atleast_2d(*arys): ``Ascend`` ``GPU`` ``CPU`` Examples: + >>> import mindspore.numpy as np >>> a = np.ones((2, 3)) >>> b = np.ones(()) >>> c = np.ones(5) @@ -824,12 +866,12 @@ def atleast_3d(*arys): tensors. Args: - arys1, arys2, … (Tensor): one or more input tensors. + *arys (Tensor): one or more input tensors. Returns: - Tensor, or list of tensors, each with a.ndim >= 3. For example, - a 1-D array of shape (N,) becomes a view of shape (1, N, 1), and - a 2-D array of shape (M, N) becomes a view of shape (M, N, 1). + Tensor, or list of tensors, each with ``a.ndim >= 3``. For example, + a 1-D array of shape `(N,)` becomes a view of shape `(1, N, 1)`, and + a 2-D array of shape `(M, N)` becomes a view of shape `(M, N, 1)`. Raises: TypeError: if the input is not a tensor. @@ -838,6 +880,7 @@ def atleast_3d(*arys): ``Ascend`` ``GPU`` ``CPU`` Examples: + >>> import mindspore.numpy as np >>> a = np.ones((2, 3)) >>> b = np.ones(()) >>> c = np.ones(5) @@ -870,17 +913,16 @@ def stack(arrays, axis=0): """ Joins a sequence of arrays along a new axis. - The axis parameter specifies the index of the new axis in the - dimensions of the result. For example, if axis=0 it will be the - first dimension and if axis=-1 it will be the last dimension. - + The `axis` parameter specifies the index of the new axis in the + dimensions of the result. For example, if ``axis=0`` it will be the + first dimension and if ``axis=-1`` it will be the last dimension. Note: Numpy argument out is not supported. Args: arrays (sequence of Tensor): Each array must have the same shape. - axis (int): optional. The axis in the result array along which the + axis (int, optional): The axis in the result array along which the input arrays are stacked. 
Returns: @@ -894,6 +936,7 @@ def stack(arrays, axis=0): ``Ascend`` ``GPU`` ``CPU`` Examples: + >>> import mindspore.numpy as np >>> arrays = [np.ones((3, 4)) for _ in range(10)] >>> output = np.stack(arrays, axis=0) >>> print(output.shape) @@ -905,23 +948,22 @@ def stack(arrays, axis=0): >>> print(output.shape) (3, 4, 10) """ - arr_type = F.typeof(arrays) - if _check_is_tensor(arr_type): + if isinstance(arrays, Tensor): shape = F.shape(arrays) ndim = F.rank(arrays) axis = axis % ndim axes = F.make_range(ndim) perm = axes[1:axis+1] + (0,) + axes[axis+1:] - if _is_empty(shape): + if _is_shape_empty(shape): return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:]) return transpose(arrays, perm) - if _check_is_tuple(arr_type) or _check_is_list(arr_type): + if isinstance(arrays, (list, tuple)): shape = (len(arrays),) + F.shape(arrays[0]) ndim = len(shape) axis = axis % ndim - if _is_empty(shape): + if _is_shape_empty(shape): return _empty(mstype.float32, shape[1:axis+1] + (shape[0],) + shape[axis+1:]) seq = () for arr in arrays: @@ -931,7 +973,7 @@ def stack(arrays, axis=0): class UniqueNet(Cell): - """The operation `mindspore.ops.Unique` must be wrapped inside a model and executed in graph mode. """ + """The operation is wrapped inside a model. """ def __init__(self): super(UniqueNet, self).__init__() @@ -948,40 +990,38 @@ def unique(x, return_inverse=False): Note: Numpy arguments `axis`, `return_index` and `return_counts` are not supported. - This operator must be executed in graph mode. + On CPU, this operator must be executed in graph mode. Args: x (Tensor): The input tensor to be processed. - return_inverse (bool): If True, also return the indices of the unique tensor. - Default: False. + return_inverse (bool): If `True`, also return the indices of the unique tensor. + Default: `False`. Returns: Tensor or tuple of Tensors. - - If `return_inverse` is False, just return the unique tensor. - - If `return_inverse` is True, return tuple of tensors. + - If `return_inverse` is `False`, just return the unique tensor. + - If `return_inverse` is `True`, return tuple of tensors. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Raises: - TypeError: If x is not tensor. + TypeError: If `x` is not tensor. Examples: - >>> import mindspore.numpy as mnp - >>> import numpy as onp + >>> import mindspore.numpy as np >>> from mindspore import context >>> context.set_context(mode=context.GRAPH_MODE) - >>> input_x = mnp.asarray(onp.array([1, 2, 2, 2, 3, 4, 5]).astype('int32')) - >>> output_x = mnp.unique(input_x) + >>> input_x = np.asarray([1, 2, 2, 2, 3, 4, 5]).astype('int32') + >>> output_x = np.unique(input_x) >>> print(output_x) [1, 2, 3, 4, 5] - >>> output_x = mnp.unique(input_x, return_inverse=True) + >>> output_x = np.unique(input_x, return_inverse=True) >>> print(output_x) (Tensor(shape=[5], dtype=Int32, value= [ 1, 2, 3, 4, 5]), Tensor(shape=[7], dtype=Int32, value= [0, 1, 1, 1, 2, 3, 4])) """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Tensor is expected, but got", x) + _check_input_tensor(x) if F.tuple_len(F.shape(x)) > 1: x = ravel(x) uniq = UniqueNet() @@ -989,3 +1029,887 @@ def unique(x, return_inverse=False): if not return_inverse: return res[0] return res + + +def roll_along_axis(a, shift, axis): + """ + Rolls a tensor along a given axis. This is a helper function of np.roll. + + Args: + a (Tensor): Input tensor. + shift (int): The number of places the tensor is shifted. + axis (int): The designated axis for shifting. + + Returns: + Shifted tensor. 
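A worked sketch of the slicing strategy used below, assuming a 1-D float input (values derived by hand; illustrative, not taken from the patch):

    >>> a = np.arange(5).astype('float32')    # [0. 1. 2. 3. 4.]
    >>> # shift=2 is normalized to -(2 % 5) = -2, so the two strided slices are
    >>> # a[-2:] -> [3. 4.] and a[:-2] -> [0. 1. 2.]; appending them along axis 0
    >>> # reproduces the rolled order
    >>> print(np.roll(a, 2))
    [3. 4. 0. 1. 2.]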
+ """ + _check_axis_in_range(axis, a.ndim) + _check_element_int((shift, axis)) + if axis < 0: + axis += a.ndim + shift = -(shift % a.shape[axis]) + # if shift is 0, we do not need to roll at all + if shift == 0: + return a + begin1 = () + begin2 = () + end1 = () + end2 = () + stride = _list_comprehensions(a.ndim, 1, True) + for i in F.make_range(a.ndim): + if i != axis: + begin1 += (0,) + end1 += (a.shape[i],) + begin2 += (0,) + end2 += (a.shape[i],) + else: + begin1 += (shift,) + end1 += (a.shape[i],) + begin2 += (0,) + end2 += (shift,) + return append(F.strided_slice(a, begin1, end1, stride), + F.strided_slice(a, begin2, end2, stride), axis=axis) + + +def roll(a, shift, axis=None): + """ + Rolls a tensor along given axes. + + Elements that rolls beyond the last position are re-introduced at the first. + + Args: + a (Tensor): Input tensor. + shift (Union[int, tuple(int)]: The number of places by which elements are + shifted. If a tuple, then axis must be a tuple of the same size, and + each of the given axes is shifted by the corresponding number. If shift + is an int while axis is a tuple of ints, then the same value is used + for all given axes. + axis (Union[int, tuple(int)], optional): Axis or axes along which elements + are shifted. By default, the array is flattened before shifting, after + which the original shape is restored. + + Returns: + Tensor, with the same shape as a. + + Raises: + TypeError: If input arguments have types not specified above. + ValueError: If axis exceeds `a.ndim`, or `shift` and `axis` cannot broadcast. + + Examples: + >>> import mindspore.numpy as np + >>> a = np.reshape(np.arange(12), (3, 4)) + >>> print(np.roll(a, [2,-3], [0,-1])) + [[ 7 4 5 6] + [11 8 9 10] + [ 3 0 1 2]] + """ + _check_input_tensor(a) + original_shape = a.shape + original_dtype = a.dtype + restore_shape = False + # F.strided_slice only supports float on cpu, this will change once more supports + # are added. + if not _check_is_float(original_dtype): + a = a.astype(mstype.float32) + if axis is None: + restore_shape = True + axis = 0 + a = a.ravel() + # Broadcast shift and axis to the same length + shift, axis = _broadcast_tuples(shift, axis) + for shift_each, axis_each in zip(shift, axis): + a = roll_along_axis(a, shift_each, axis_each) + if restore_shape: + a = a.reshape(original_shape) + if not _check_is_float(original_dtype): + a = a.astype(original_dtype) + return a + + +@constexpr +def _get_moved_perm(ndim, source, destination): + """ + Helper function for moveaxis, returns permutation after moving axes + from source to destination. + """ + dest_sorted_idx = [i for i, _ in sorted(enumerate(destination), + key=operator.itemgetter(1))] + axes_orig = [i for i in range(ndim) if i not in source] + + k = 0 + m = 0 + perm = [] + for i in dest_sorted_idx: + # inserts an axis that has been moved, denoted by n, and axes that remain + # in their original position, indexed from k to k + n - m, into index m in + # the list of permuted axes + n = destination[i] + j = k + n - m + perm += axes_orig[k:j] + perm.append(source[i]) + k += n - m + m = n + 1 + perm += axes_orig[k:] + return tuple(perm) + + +@constexpr +def _get_moved_shape(shape, perm): + """ + Helper function for moveaxis, returns the permuated shape after + applying perm. + """ + return tuple([shape[i] for i in perm]) + + +def moveaxis(a, source, destination): + """ + Moves axes of an array to new positions. + + Other axes remain in their original order. + + Args: + a (Tensor): The array whose axes should be reordered. 
+ source (int or sequence of ints): Original positions of the + axes to move. These must be unique. + destination (int or sequence of ints): Destination positions + for each of the original axes. These must also be unique. + + Returns: + Tensor, array with moved axes. + + Raises: + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.zeros((3, 4, 5)) + >>> output = np.moveaxis(x, 0, -1) + >>> print(output.shape) + (4, 5, 3) + >>> output = np.moveaxis(x, -1, 0) + >>> print(output.shape) + (5, 3, 4) + >>> output = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) + >>> print(output.shape) + (5, 4, 3) + """ + ndim = F.rank(a) + source = _check_axis_valid(source, ndim) + destination = _check_axis_valid(destination, ndim) + perm = _get_moved_perm(ndim, source, destination) + + shape = F.shape(a) + if _is_shape_empty(shape): + return _empty(F.dtype(a), _get_moved_shape(shape, perm)) + + return F.transpose(a, perm) + + +@constexpr +def _seq_prod(seq1, seq2): + """Returns the element-wise product of seq1 and seq2.""" + return tuple(map(lambda x, y: x*y, seq1, seq2)) + + +def tile(a, reps): + """ + Constructs an array by repeating `a` the number of times given by `reps`. + + If `reps` has length `d`, the result will have dimension of ``max(d, a.ndim)``. + If ``a.ndim < d``, `a` is promoted to be d-dimensional by prepending new axes. + So a shape (3,) array is promoted to (1, 3) for 2-D replication, or + shape (1, 1, 3) for 3-D replication. If this is not the desired behavior, + promote `a` to d-dimensions manually before calling this function. + If ``a.ndim > d``, `reps` is promoted to ``a.ndim`` by pre-pending 1’s to it. Thus + for an `a` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as (1, 1, 2, 2). + + Args: + a (Tensor): The input array. + reps (int or sequence of ints): The number of repetitions of `a` along + each axis. + + Returns: + Tensor, the tiled output array. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([0, 1, 2]) + >>> output = np.tile(a, 2) + >>> print(output) + [0 1 2 0 1 2] + >>> output = np.tile(a, (2, 2)) + >>> print(output) + [[0 1 2 0 1 2] + [0 1 2 0 1 2]] + >>> output = np.tile(a, (2, 1, 2)) + >>> print(output) + [[[0 1 2 0 1 2]] + [[0 1 2 0 1 2]]] + """ + _check_input_tensor(a) + ndim = F.rank(a) + shape = F.shape(a) + reps = _add_unit_axes(reps, ndim) + if _is_shape_empty(shape) or _is_shape_empty(reps): + shape = _add_unit_axes(shape, len(reps)) + return _empty(F.dtype(a), _seq_prod(shape, reps)) + return F.tile(a, reps) + + +@constexpr +def _check_can_broadcast_to(shape, target_shape): + """Determines if shape can be broadcast to target_shape.""" + ndim = len(shape) + ndim_target = len(target_shape) + if ndim > ndim_target: + return False + for i, j in zip(reversed(shape), reversed(target_shape)): + if i not in (1, j): + return False + return True + + +def broadcast_to(array, shape): + """ + Broadcasts an array to a new shape. + + Args: + array (Tensor): The array to broadcast. + shape (tuple): The shape of the desired array. + + Returns: + Tensor, original array broadcast to the given shape. + + Raises: + ValueError: if array cannot be broadcast to shape. 
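As a quick illustration of the right-aligned shape comparison performed by `_check_can_broadcast_to` above (a sketch, not a definitive spec):

    >>> # (2, 1) vs (2, 3): trailing pairs (1, 3) and (2, 2) are compatible
    >>> print(np.broadcast_to(np.ones((2, 1)), (2, 3)).shape)
    (2, 3)
    >>> # (2,) vs (3, 3): the trailing pair (2, 3) is incompatible, so
    >>> # np.broadcast_to(np.ones((2,)), (3, 3)) raises ValueError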
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> x = np.array([1, 2, 3]) + >>> output = np.broadcast_to(x, (3, 3)) + >>> print(output) + [[1 2 3] + [1 2 3] + [1 2 3]] + """ + shape_a = F.shape(array) + if not _check_can_broadcast_to(shape_a, shape): + return _raise_value_error('cannot broadcaast with {shape_a} {shape}') + return _broadcast_to_shape(array, shape) + + +def broadcast_arrays(*args): + """ + Broadcasts any number of arrays against each other. + + Note: + Numpy argument `subok` is not supported. + In graph mode, returns a tuple of Tensor instead of a list + of Tensor. + + Args: + *args (Tensor): The arrays to broadcast. + + Returns: + List of Tensor. + + Raises: + ValueError: if arrays cannot be broadcast. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> x = np.array([[1,2,3]]) + >>> y = np.array([[4],[5]]) + >>> output = np.broadcast_arrays(x, y) + >>> print(output) + [Tensor(shape=[2, 3], dtype=Int32, value= + [[1, 2, 3], + [1, 2, 3]]), Tensor(shape=[2, 3], dtype=Int32, value= + [[4, 4, 4], + [5, 5, 5]])] + """ + shapes = map(F.shape, args) + out_shape = _infer_out_shape(*shapes) + res = [] + for arr in args: + res.append(broadcast_to(arr, out_shape)) + return res + + +def split(x, indices_or_sections, axis=0): + """ + Splits a tensor into multiple sub-tensors along the given axis. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If integer, :math:`N`, the tensor will be divided into + :math:`N` equal tensors along axis. + If tuple(int), list(int) or of sorted integers, + the entries indicate where along axis the array is split. + For example, :math:`[2, 3]` would, for :math:`axis=0`, result in + three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. + If an index exceeds the dimension of the array along axis, + an empty sub-array is returned correspondingly. + axis (int): The axis along which to split. Default: 0. + + Returns: + A list of sub-tensors. + + Raises: + TypeError: If argument `indices_or_sections` is not integer, + tuple(int) or list(int) or argument `axis` is not integer. + ValueError: If argument `axis` is out of range of :math:`[-x.ndim, x.ndim)`. + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.arange(9) + >>> output = np.split(input_x, 3) + >>> print(output) + (Tensor(shape=[3], dtype=Float32, + value= [ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]), + Tensor(shape=[3], dtype=Float32, + value= [ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]), + Tensor(shape=[3], dtype=Float32, + value= [ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00])) + """ + _ = _check_axis_type(axis, True, False, False) + axis = _canonicalize_axis(axis, x.ndim) + res = None + if isinstance(indices_or_sections, int): + _split = P.Split(axis, indices_or_sections) + res = _split(x) + elif isinstance(indices_or_sections, (list, tuple)) and _check_element_int(indices_or_sections): + res = _split_sub_tensors(x, indices_or_sections, axis) + else: + _raise_type_error("Argument `indices_or_sections` in `mindspore.numpy.split`\ + should be integer, tuple(int) or list(int), but got", indices_or_sections) + return res + + +def _split_sub_tensors(x, indices, axis): + """ + Splits the input tensor `x` into multiple sub-tensors + along the axis according to the given indices. 
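A small worked example of how the section indices map to strided slices in this helper (illustrative only):

    >>> x = np.arange(9)
    >>> # indices [2, 3] are extended with x.shape[0] = 9, yielding the slices
    >>> # x[0:2], x[2:3] and x[3:9]
    >>> for sub in np.split(x, [2, 3]):
    ...     print(sub.shape)
    (2,)
    (1,)
    (6,)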
+ """ + if indices[-1] < x.shape[axis]: + if isinstance(indices, list): + indices.append(x.shape[axis]) + elif isinstance(indices, tuple): + indices += (x.shape[axis],) + + sub_tensors = [] + strides = _list_comprehensions(x.ndim, 1, True) + begin = _list_comprehensions(x.ndim, 0) + end = _list_comprehensions(x.shape) + for i, idx in enumerate(indices): + begin[axis] = 0 if i == 0 else indices[i-1] + end[axis] = idx + sliced_tensor = F.strided_slice(x, _type_convert(tuple, begin), _type_convert(tuple, end), strides) + sub_tensors.append(sliced_tensor) + return sub_tensors + + +def vsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors vertically (row-wise). + It is equivalent to split with :math:`axis=0` (default), the array is always + split along the first axis regardless of the array dimension. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If integer, :math:`N`, the tensor will be divided into + :math:`N` equal tensors along axis. + If tuple(int), list(int) or of sorted integers, + the entries indicate where along axis the array is split. + For example, :math:`[2, 3]` would, for :math:`axis=0`, result in + three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. + If an index exceeds the dimension of the array along axis, + an empty sub-array is returned correspondingly. + + Returns: + A list of sub-tensors. + + Raises: + TypeError: If argument `indices_or_sections` is not integer. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.arange(9).reshape((3, 3)) + >>> output = np.vsplit(input_x, 3) + >>> print(output) + (Tensor(shape=[1, 3], dtype=Float32, + value=[[ 0.00000000e+00, 1.00000000e+00, 2.00000000e+00]]), + Tensor(shape=[1, 3], dtype=Float32, + value=[[ 3.00000000e+00, 4.00000000e+00, 5.00000000e+00]]), + Tensor(shape=[1, 3], dtype=Float32, + value=[[ 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])) + """ + return split(x, indices_or_sections, 0) + + +def hsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors horizontally (column-wise). + It is equivalent to split with :math:`axis=1` (default), the array is always + split along the second axis regardless of the array dimension. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If integer, :math:`N`, the tensor will be divided into + :math:`N` equal tensors along axis. + If tuple(int), list(int) or of sorted integers, + the entries indicate where along axis the array is split. + For example, :math:`[2, 3]` would, for :math:`axis=0`, result in + three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. + If an index exceeds the dimension of the array along axis, + an empty sub-array is returned correspondingly. + + Returns: + A list of sub-tensors. + + Raises: + TypeError: If argument `indices_or_sections` is not integer. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.arange(6).reshape((2, 3)) + >>> output = np.hsplit(input_x, 3) + >>> print(output) + (Tensor(shape=[2, 1], dtype=Float32, + value=[[ 0.00000000e+00], + [ 3.00000000e+00]]), + Tensor(shape=[2, 1], dtype=Float32, + value=[[ 1.00000000e+00], + [ 4.00000000e+00]]), + Tensor(shape=[2, 1], dtype=Float32, + value=[[ 2.00000000e+00], + [ 5.00000000e+00]])) + """ + return split(x, indices_or_sections, 1) + + +def dsplit(x, indices_or_sections): + """ + Splits a tensor into multiple sub-tensors along the 3rd axis (depth). + It is equivalent to split with :math:`axis=2` (default), the array is always + split along the third axis regardless of the array dimension. + + Args: + x (Tensor): A Tensor to be divided. + indices_or_sections (Union[int, tuple(int), list(int)]): + If integer, :math:`N`, the tensor will be divided into + :math:`N` equal tensors along axis. + If tuple(int), list(int) or of sorted integers, + the entries indicate where along axis the array is split. + For example, :math:`[2, 3]` would, for :math:`axis=0`, result in + three sub-tensors :math:`x[:2]`, :math:`x[2:3]`and :math:`x[3:]`. + If an index exceeds the dimension of the array along axis, + an empty sub-array is returned correspondingly. + + Returns: + A list of sub-tensors. + + Raises: + TypeError: If argument `indices_or_sections` is not integer. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.arange(6).reshape((1, 2, 3)) + >>> output = np.dsplit(input_x, 3) + >>> print(output) + (Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 0.00000000e+00], + [ 3.00000000e+00]]]), + Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 1.00000000e+00], + [ 4.00000000e+00]]]), + Tensor(shape=[1, 2, 1], dtype=Float32, + value=[[[ 2.00000000e+00], + [ 5.00000000e+00]]])) + """ + return split(x, indices_or_sections, 2) + + +@constexpr +def _get_flip_start(ndim, shape, axes): + return tuple([shape[i] - 1 if i in axes else 0 for i in range(ndim)]) + + +@constexpr +def _get_flip_end(ndim, shape, axes): + return tuple([-shape[i] - 1 if i in axes else shape[i] + 1 for i in range(ndim)]) + + +@constexpr +def _get_flip_strides(ndim, axes): + return tuple([-1 if i in axes else 1 for i in range(ndim)]) + + +def flip(m, axis=None): + """ + Reverses the order of elements in an array along the given axis. + + The shape of the array is preserved, but the elements are reordered. + + Note: + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. + + Args: + m (Tensor): Input array. + axis (None or int or tuple of ints, optional): Axis or axes along which + to flip over. The default, ``axis=None``, will flip over all of the axes + of the input array. If `axis` is negative it counts from the last to + the first axis. If `axis` is a tuple of ints, flipping is performed on + all of the axes specified in the tuple. + + Returns: + Tensor, a view of `m` with the entries of `axis` reversed. + + Raises: + TypeError: if the input is not a tensor. 
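For intuition, the three `_get_flip_*` helpers above simply encode a reversed slice: on an axis of length 4 they yield start 3, end -5 and stride -1, which visits indices 3, 2, 1, 0. A minimal sketch (illustrative):

    >>> A = np.arange(4.0)
    >>> print(np.flip(A))
    [3. 2. 1. 0.]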
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> A = np.arange(8.0).reshape((2,2,2)) + >>> output = np.flip(A) + >>> print(output) + [[[7, 6], + [5, 4]], + [[3, 2], + [1, 0]]] + >>> output = np.flip(A, (0, 2)) + >>> print(output) + [[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]] + """ + _check_input_tensor(m) + ndim = F.rank(m) + axes = _check_axis_valid(axis, ndim) + shape = F.shape(m) + dtype = F.dtype(m) + if _is_shape_empty(shape): + return m + if not _check_is_float(dtype): + m = m.astype(mstype.float32) + start = _get_flip_start(ndim, shape, axes) + end = _get_flip_end(ndim, shape, axes) + strides = _get_flip_strides(ndim, axes) + res = F.strided_slice(m, start, end, strides) + if not _check_same_type(F.dtype(res), dtype): + res = F.cast(res, dtype) + return res + + +def flipud(m): + """ + Flips the entries in each column in the up/down direction. + Rows are preserved, but appear in a different order than before. + + Note: + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. + + Args: + m (Tensor): Input array. + + Returns: + Tensor. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> A = np.arange(8.0).reshape((2,2,2)) + >>> output = np.flipud(A) + >>> print(output) + [[[4., 5.], + [6., 7.]], + [[0., 1.], + [2., 3.]]] + """ + return flip(m, 0) + + +def fliplr(m): + """ + Flip the entries in each row in the left/right direction. + Columns are preserved, but appear in a different order than before. + + Note: + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. + + Args: + m (Tensor): Input array. + + Returns: + Tensor. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> A = np.arange(8.0).reshape((2,2,2)) + >>> output = np.fliplr(A) + >>> print(output) + [[[2., 3.], + [0., 1.]], + [[6., 7.], + [4., 5.]]] + """ + return flip(m, 1) + + +def take_along_axis(arr, indices, axis): + """ + Takes values from the input array by matching 1d index and data slices. + + This iterates over matching 1d slices oriented along the specified axis in the + index and data arrays, and uses the former to look up values in the latter. + These slices can be different lengths. + + Args: + arr (Tensor): Source array with shape `(Ni…, M, Nk…)`. + indices (Tensor): Indices with shape `(Ni…, J, Nk…)` to take along each 1d + slice of `arr`. This must match the dimension of `arr`, but dimensions `Ni` + and `Nj` only need to broadcast against `arr`. + axis (int): The axis to take 1d slices along. If `axis` is None, the input + array is treated as if it had first been flattened to 1d. + + Returns: + Tensor, the indexed result, with shape `(Ni…, J, Nk…)`. + + Raises: + ValueError: if input array and indices have different number of dimensions. + TypeError: if the input is not a Tensor. 
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Example: + >>> x = np.arange(12).reshape(3, 4) + >>> indices = np.arange(3).reshape(1, 3) + >>> output = np.take_along_axis(x, indices, 1) + >>> print(output) + [[ 0 1 2] + [ 4 5 6] + [ 8 9 10]] + """ + _check_input_tensor(arr, indices) + if axis is None: + arr = ravel(arr) + axis = 0 + ndim = F.rank(arr) + if ndim != F.rank(indices): + _raise_value_error('`indices` and `arr` must have the same number of dimensions') + _check_axis_in_range(axis, ndim) + axis = axis + ndim if axis < 0 else axis + + shape_arr = F.shape(arr) + shape_indices = F.shape(indices) + # broadcasts indices against the shape of arr except at axis + indices = _broadcast_to(indices, _tuple_getitem(shape_indices, axis, False), + _tuple_getitem(shape_arr, axis, False), ndim) + indices = _broadcast_to(indices, _tuple_getitem(shape_arr, axis + 1, False) + + _tuple_getitem(shape_indices, axis + 1), shape_arr, ndim) + return F.gather_d(arr, axis, indices) + + +def _mod(x, y): + """Computes x mod y.""" + quotient = F.tensor_floordiv(x, y) + prod = F.tensor_mul(y, quotient) + return F.tensor_sub(x, prod) + + +def _check_indices(size, indices, mode): + """Checks whether indices are out of bounds.""" + shape = F.shape(indices) + dtype = F.dtype(indices) + lowerbounds = F.fill(dtype, shape, -size) + upperbounds = F.fill(dtype, shape, size - 1) + out_of_lowerbounds = F.tensor_lt(indices, lowerbounds) + out_of_upperbounds = F.tensor_gt(indices, upperbounds) + if mode == 'raise': + # For mode raise, index-out-of-bounds checking is performed at backend since + # evaluation of a boolean scalar Tensor always returns true in graph mode + # regardless of the truth value contained + return indices + if mode == 'wrap': + return _mod(indices, F.fill(dtype, shape, size)) + zeros = F.fill(dtype, shape, 0) + clipped = F.select(out_of_lowerbounds, zeros, indices) + clipped = F.select(out_of_upperbounds, upperbounds, clipped) + return clipped + + +def take(a, indices, axis=None, mode='raise'): + """ + Takes elements from an array along an axis. + + When axis is not None, this function does the same thing as “fancy” indexing + (indexing arrays using arrays); however, it can be easier to use if you need + elements along a given axis. A call such as ``np.take(arr, indices, axis=3)`` is + equivalent to ``arr[:,:,:,indices,...]``. + + Note: + Numpy argument out is not supported. + + Args: + a (Tensor): Source array with shape `(Ni…, M, Nk…)`. + indices (Tensor): The indices with shape `(Nj...)` of the values to extract. + axis (int, optional): The axis over which to select values. By default, + the flattened input array is used. + mode (‘raise’, ‘wrap’, ‘clip’, optional): Specifies how out-of-bounds + indices will behave. + + ‘raise’ – raise an error (default); + + ‘wrap’ – wrap around; + + ‘clip’ – clip to the range. ‘clip’ mode means that all indices that are + too large are replaced by the index that addresses the last element + along that axis. Note that this disables indexing with negative numbers. + + Returns: + Tensor, the indexed result. + + Raises: + ValueError: if axis is out of range. + TypeError: if the input is not a Tensor. 
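To illustrate the `wrap` and `clip` handling implemented by `_check_indices` above, a short sketch with out-of-bounds indices (outputs worked out by hand from that logic; illustrative only):

    >>> a = np.array([4, 3, 5, 7, 6, 8])
    >>> indices = np.array([6, -7])
    >>> print(np.take(a, indices, mode='wrap'))   # 6 wraps to 0, -7 wraps to 5
    [4 8]
    >>> print(np.take(a, indices, mode='clip'))   # 6 clips to 5, -7 clips to 0
    [8 4]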
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> output = np.take(a, indices) + >>> print(output) + [4 3 6] + >>> indices = np.array([[0, 1], [2, 3]]) + >>> output = np.take(a, indices) + >>> print(output) + [[4 3] + [5 7]] + """ + _check_input_tensor(a, indices) + if axis is None: + a = ravel(a) + axis = 0 + ndim = F.rank(a) + _check_axis_in_range(axis, ndim) + axis = axis + ndim if axis < 0 else axis + + shape_a = F.shape(a) + shape_indices = F.shape(indices) + size_indices = indices.size + indices = _check_indices(shape_a[axis], indices, mode) + + # reshapes indices to shape (Ni..., Nj..., Nk) + shape_ni = _tuple_getitem(shape_a, axis, False) + shape_nk = _tuple_getitem(shape_a, axis + 1) + shape_out = shape_ni + shape_indices + shape_nk + shape_indices = _expanded_shape(ndim, size_indices, axis) + indices = F.reshape(indices, shape_indices) + shape_indices = shape_ni + (indices.size,) + shape_nk + indices = _broadcast_to_shape(indices, shape_indices) + + res = F.gather_d(a, axis, indices) + return F.reshape(res, shape_out) + + +def repeat(a, repeats, axis=None): + """ + Repeats elements of an array. + + Args: + a (Tensor): Input array. + repeats (int or sequence of ints): The number of repetitions for each element. + `repeats` is broadcasted to fit the shape of the given axis. + axis (int, optional): The axis along which to repeat values. By default, + use the flattened input array, and return a flat output array. + + Returns: + Tensor, output array which has the same shape as `a`, except along the given + axis. + + Raises: + ValueError: if axis is out of range. + TypeError: if input `a` is not a Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.repeat(np.array(3), 4) + >>> print(output) + [3 3 3 3] + >>> x = np.array([[1,2],[3,4]]) + >>> output = np.repeat(x, 2) + >>> print(output) + [1 1 2 2 3 3 4 4] + >>> output = np.repeat(x, 3, axis=1) + >>> print(output) + [[1 1 1 2 2 2] + [3 3 3 4 4 4]] + >>> output = np.repeat(x, [1, 2], axis=0) + >>> print(output) + [[1 2] + [3 4] + [3 4]] + """ + _check_input_tensor(a) + if axis is None: + a = ravel(a) + axis = 0 + ndim = F.rank(a) + _check_axis_in_range(axis, ndim) + axis = axis + ndim if axis < 0 else axis + if isinstance(repeats, (tuple, list)) and len(repeats) == 1: + repeats = repeats[0] + if isinstance(repeats, int): + if repeats == 0: + return _empty(F.dtype(a), (0,)) + return C.repeat_elements(a, repeats, axis) + shape = F.shape(a) + size = shape[axis] + subs = split(a, size, axis) + repeated_subs = [] + for sub, rep in zip(subs, repeats): + if rep != 0: + repeated_subs.append(C.repeat_elements(sub, rep, axis)) + return concatenate(repeated_subs, axis) diff --git a/mindspore/numpy/dtypes.py b/mindspore/numpy/dtypes.py index acdcb45668..1c1a946a1f 100644 --- a/mindspore/numpy/dtypes.py +++ b/mindspore/numpy/dtypes.py @@ -22,7 +22,12 @@ from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, ui # backend for now. 
 inf = float('inf')
+PINF = float('inf')
+NINF = float('-inf')
 nan = float('nan')
+# inf, PINF, and NINF are all defined in the original numpy, so they are all defined
+# here as well for consistency
+pi = 3.141592653589793
 
 int_ = int32
 uint = uint32
diff --git a/mindspore/numpy/logic_ops.py b/mindspore/numpy/logic_ops.py
new file mode 100644
index 0000000000..95f2c0ac92
--- /dev/null
+++ b/mindspore/numpy/logic_ops.py
@@ -0,0 +1,576 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""logical operations, the function docs are adapted from Numpy API."""
+
+
+from .math_ops import _apply_tensor_op
+from ..ops import functional as F
+from ..common import dtype as mstype
+from .._c_expression import typing
+
+from .array_creations import zeros, ones
+from .utils import _check_input_tensor
+
+
+def not_equal(x1, x2, out=None, where=True, dtype=None):
+    """
+    Returns (x1 != x2) element-wise.
+
+    Args:
+        x1 (Tensor): First input tensor to be compared.
+        x2 (Tensor): Second input tensor to be compared.
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
+            uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type
+        bool, unless `dtype` is passed. This is a scalar if both `x1` and `x2` are
+        scalars.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.numpy as np
+        >>> a = np.asarray([1, 2])
+        >>> b = np.asarray([[1, 3],[1, 4]])
+        >>> print(np.not_equal(a, b))
+        [[False True]
+        [False True]]
+    """
+    _check_input_tensor(x1, x2)
+    return _apply_tensor_op(F.not_equal, x1, x2, out=out, where=where, dtype=dtype)
+
+
+def less_equal(x1, x2, out=None, where=True, dtype=None):
+    """
+    Returns the truth value of ``(x1 <= x2)`` element-wise.
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+
+    Args:
+        x1 (Tensor): Input array.
+        x2 (Tensor): Input array.
If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.less_equal(np.array([4, 2, 1]), np.array([2, 2, 2])) + >>> print(output) + [False True True] + """ + _check_input_tensor(x1, x2) + return _apply_tensor_op(F.tensor_le, x1, x2, out=out, where=where, dtype=dtype) + + +def less(x1, x2, out=None, where=True, dtype=None): + """ + Returns the truth value of ``(x1 < x2)`` element-wise. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): input array. + x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.less(np.array([1, 2]), np.array([2, 2])) + >>> print(output) + [ True False] + """ + return _apply_tensor_op(F.tensor_lt, x1, x2, out=out, where=where, dtype=dtype) + + +def greater_equal(x1, x2, out=None, where=True, dtype=None): + """ + Returns the truth value of ``(x1 >= x2)`` element-wise. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. 
`out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): Input array. + x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.greater_equal(np.array([4, 2, 1]), np.array([2, 2, 2])) + >>> print(output) + [ True True False] + """ + return _apply_tensor_op(F.tensor_ge, x1, x2, out=out, where=where, dtype=dtype) + + +def greater(x1, x2, out=None, where=True, dtype=None): + """ + Returns the truth value of ``(x1 > x2)`` element-wise. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): Input array. + x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.greater(np.array([4, 2]), np.array([2, 2])) + >>> print(output) + [ True False] + """ + return _apply_tensor_op(F.tensor_gt, x1, x2, out=out, where=where, dtype=dtype) + + +def equal(x1, x2, out=None, where=True, dtype=None): + """ + Returns the truth value of ``(x1 == x2)`` element-wise. 
+ + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): Input array. + x2 (Tensor): Input array. If ``x1.shape != x2.shape``, they must be + broadcastable to a common shape (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise comparison of `x1` and `x2`. Typically of type + bool, unless ``dtype=object`` is passed. This is a scalar if both `x1` and `x2` are + scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.equal(np.array([0, 1, 3]), np.arange(3)) + >>> print(output) + [ True True False] + """ + return _apply_tensor_op(F.equal, x1, x2, out=out, where=where, dtype=dtype) + + +def isfinite(x, out=None, where=True, dtype=None): + """ + Tests element-wise for finiteness (not infinity or not Not a Number). + + The result is returned as a boolean array. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + x (Tensor): Input values. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, true where `x` is not positive infinity, negative infinity, + or NaN; false otherwise. This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. 
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.isfinite(np.array([np.inf, 1., np.nan]).astype('float32'))
+        >>> print(output)
+        [False True False]
+        >>> output = np.isfinite(np.log(np.array(-1.).astype('float32')))
+        >>> print(output)
+        False
+    """
+    return _apply_tensor_op(F.isfinite, x, out=out, where=where, dtype=dtype)
+
+
+def _isnan(x):
+    """Computes isnan without applying keyword arguments."""
+    shape = F.shape(x)
+    zeros_tensor = zeros(shape, mstype.float32)
+    ones_tensor = ones(shape, mstype.float32)
+    non_neg = F.tensor_ge(x, zeros_tensor)
+    non_pos = F.tensor_le(x, zeros_tensor)
+    res = F.select(non_neg, zeros_tensor, ones_tensor)
+    res = F.select(non_pos, zeros_tensor, res)
+    return F.cast(res, mstype.bool_)
+
+
+def isnan(x, out=None, where=True, dtype=None):
+    """
+    Tests element-wise for NaN and returns the result as a boolean array.
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+        On GPU, the supported dtypes are np.float16, and np.float32.
+
+    Args:
+        x (Tensor): Input values.
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
+            uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, true where `x` is NaN, false otherwise. This is a scalar if
+        `x` is a scalar.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.isnan(np.array(np.nan, np.float32))
+        >>> print(output)
+        True
+        >>> output = np.isnan(np.array(np.inf, np.float32))
+        >>> print(output)
+        False
+    """
+    return _apply_tensor_op(_isnan, x, out=out, where=where, dtype=dtype)
+
+
+def _isinf(x):
+    """Computes isinf without applying keyword arguments."""
+    shape = F.shape(x)
+    zeros_tensor = zeros(shape, mstype.float32)
+    ones_tensor = ones(shape, mstype.float32)
+    not_inf = F.isfinite(x)
+    is_nan = _isnan(x)
+    res = F.select(not_inf, zeros_tensor, ones_tensor)
+    res = F.select(is_nan, zeros_tensor, res)
+    return F.cast(res, mstype.bool_)
+
+
+def isinf(x, out=None, where=True, dtype=None):
+    """
+    Tests element-wise for positive or negative infinity.
+
+    Returns a boolean array of the same shape as `x`, True where ``x == +/-inf``, otherwise False.
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+        On GPU, the supported dtypes are np.float16, and np.float32.
+
+    Args:
+        x (Tensor): Input values.
+
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
+            uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, true where `x` is positive or negative infinity, false
+        otherwise. This is a scalar if `x` is a scalar.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.isinf(np.array(np.inf, np.float32))
+        >>> print(output)
+        True
+        >>> output = np.isinf(np.array([np.inf, -np.inf, 1.0, np.nan], np.float32))
+        >>> print(output)
+        [ True True False False]
+    """
+    return _apply_tensor_op(_isinf, x, out=out, where=where, dtype=dtype)
+
+
+def _is_sign_inf(x, fn):
+    """Tests element-wise for infinity with sign."""
+    shape = F.shape(x)
+    zeros_tensor = zeros(shape, mstype.float32)
+    ones_tensor = ones(shape, mstype.float32)
+    not_inf = F.isfinite(x)
+    is_sign = fn(x, zeros_tensor)
+    res = F.select(not_inf, zeros_tensor, ones_tensor)
+    res = F.select(is_sign, res, zeros_tensor)
+    return F.cast(res, mstype.bool_)
+
+
+def isposinf(x):
+    """
+    Tests element-wise for positive infinity, returns result as bool array.
+
+    Note:
+        Numpy argument `out` is not supported.
+        On GPU, the supported dtypes are np.float16, and np.float32.
+
+    Args:
+        x (Tensor): Input values.
+
+    Returns:
+        Tensor or scalar, true where `x` is positive infinity, false otherwise.
+        This is a scalar if `x` is a scalar.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.isposinf(np.array([-np.inf, 0., np.inf], np.float32))
+        >>> print(output)
+        [False False True]
+    """
+    _check_input_tensor(x)
+    return _is_sign_inf(x, F.tensor_gt)
+
+
+def isneginf(x):
+    """
+    Tests element-wise for negative infinity, returns result as bool array.
+
+    Note:
+        Numpy argument `out` is not supported.
+        On GPU, the supported dtypes are np.float16, and np.float32.
+
+    Args:
+        x (Tensor): Input values.
+
+    Returns:
+        Tensor or scalar, true where `x` is negative infinity, false otherwise.
+        This is a scalar if `x` is a scalar.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.isneginf(np.array([-np.inf, 0., np.inf], np.float32))
+        >>> print(output)
+        [ True False False]
+    """
+    return _is_sign_inf(x, F.tensor_lt)
+
+
+def isscalar(element):
+    """
+    Returns True if the type of `element` is a scalar type.
+
+    Note:
+        Only object types recognized by the mindspore parser are supported,
+        which includes objects, types, methods and functions defined within
+        the scope of mindspore. Other built-in types are not supported.
+
+    Args:
+        element (any): Input argument, can be of any type and shape.
+
+    Returns:
+        Boolean, True if `element` is a scalar type, False if it is not.
+
+    Raises:
+        TypeError: if the type of `element` is not supported by mindspore parser.
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.isscalar(3.1) + >>> print(output) + True + >>> output = np.isscalar(np.array(3.1)) + >>> print(output) + False + >>> output = np.isscalar(False) + >>> print(output) + True + >>> output = np.isscalar('numpy') + >>> print(output) + True + """ + return isinstance(F.typeof(element), (typing.Number, typing.Int, typing.UInt, + typing.Float, typing.Bool, typing.String)) diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 721ba48522..6ffe003c48 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -13,53 +13,69 @@ # limitations under the License. # ============================================================================ """math operations, the function docs are adapted from Numpy API.""" +import operator +import functools + from ..ops import operations as P from ..ops import functional as F from ..ops import composite as C from ..ops.primitive import constexpr from ..common import dtype as mstype -from .array_ops import ravel +from ..common import Tensor + +from .dtypes import nan, pi + +from .array_creations import asarray_const, ones, zeros, empty, full from .array_ops import where as where_ -from .array_creations import asarray, full -from .utils import _is_scalar, _expand, _broadcast_to, _is_empty -from .utils_const import _infer_out_shape, _check_axis_valid, _get_device_compile, \ - _check_shape_aligned, _empty, _check_is_tensor, _raise_type_error, _check_same_type, \ - _check_is_float, _check_input_tensor -from .dtypes import nan +from .array_ops import ravel, expand_dims + +from .utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ + _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ + _raise_value_error, _check_matmul_shapes, _promote, _check_axis_type, _canonicalize_axis, \ + _max, _is_shape_empty, _check_is_int +from .utils import _is_scalar, _expand, _broadcast_to, _broadcast_to_shape, _get_size, \ + _check_input_tensor + + +ZERO_TENSOR = asarray_const(0) _mean_default = P.ReduceMean() _mean_keepdims = P.ReduceMean(True) _matmul = P.MatMul(False, False) _matmul_T = P.MatMul(False, True) - +_reduce_sum_default = P.ReduceSum() +_reduce_sum_keepdims = P.ReduceSum(True) +_reduce_min_default = P.ReduceMin() +_reduce_min_keepdims = P.ReduceMin(True) +_reduce_max_default = P.ReduceMax() +_reduce_max_keepdims = P.ReduceMax(True) def absolute(x, out=None, where=True, dtype=None): """ Calculates the absolute value element-wise. Note: - Numpy arguments casting, order, dtype, subok, signature, and extobj are + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are not supported. - When argument where is provided, argument out must have a tensor value. - Argument out is not supported for storing the result, however it can be - used in combination with argument where to set the value at indices for - which where is set to False. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. Currently the backend kernel only supports float calculation, if the input - is not a float, then it will be casted to float32 and casted back. + is not a `float`, then it will be casted to :class:`mstype.float32` and casted back. Args: x (Tensor): Tensor to be used for calculation. - out (Tensor or None): optional, defaults to None. 
- where (Tensor or None): optional. For any non-default value of type other - than Tensor or None, the output retains its original value. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. This condition is broadcasted over the input. At locations where the - condition is True, the out array will be set to the ufunc result. + condition is `True`, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default out=None, - locations within it where the condition is False will remain + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain uninitialized. - dtype (data type): optional, defaults to None. Overrides the dtype of the + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: @@ -78,8 +94,6 @@ def absolute(x, out=None, where=True, dtype=None): >>> print(output) [1. 2. 3. 4. 5.] """ - if not _check_is_tensor(F.typeof(x)): - _raise_type_error("Input is expected to be a tensor, but got ", x) original_dtype = x.dtype if not _check_is_float(original_dtype) and dtype is None: x = x.astype(mstype.float32) @@ -87,36 +101,206 @@ def absolute(x, out=None, where=True, dtype=None): return _apply_tensor_op(F.absolute, x, out=out, where=where, dtype=dtype) +def count_nonzero(x, axis=None, keepdims=False): + """ + Counts the number of non-zero values in the tensor `x`. + + Args: + x (Tensor): The tensor for which to count non-zeros. + axis (Union[int,tuple], optional): Axis or tuple of axes along which to + count non-zeros. Default is None, meaning that non-zeros will be counted + along a flattened version of `x`. + keepdims (bool, optional): If this is set to True, the axes that are counted + are left in the result as dimensions with size one. With this option, + the result will broadcast correctly against `x`. + + Returns: + Tensor, indicating number of non-zero values in the `x` along a given axis. + Otherwise, the total number of non-zero values in `x` is returned. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.asarray([1, 2, 3, -4, 0, 3, 2, 0]) + >>> output = np.count_nonzero(x) + >>> print(output) + 6 + """ + if _is_shape_empty(x.shape): + return ZERO_TENSOR + if axis is None: + axis = () + return C.count_nonzero(x=x, axis=axis, keep_dims=keepdims) + + +def clip(x, xmin, xmax, out=None, where=True, dtype=None): + """ + Clips (limits) the values in an array. + + Given an interval, values outside the interval are clipped to the interval edges. + For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0, + and values larger than 1 become 1. + + Args: + x (Tensor): Tensor containing elements to clip. + xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed + on lower interval edge. Not more than one of `xmin` and `xmax` may be None. + xmax (Tensor, scalar, None): Maximum value. If None, clipping is not performed + on upper interval edge. Not more than one of `xmin` and `xmax` may be None. + If `xmin` or `xmax` are tensors, then the three tensors will be broadcasted + to match their shapes. 
+ out (Tensor or None): optional, default to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, a tensor with the elements of `x`, but where values + < `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.asarray([1, 2, 3, -4, 0, 3, 2, 0]) + >>> output = np.clip(x, 0, 2) + >>> print(output) + [1 2 2 0 0 2 2 0] + """ + if xmin is None and xmax is None: + _raise_value_error("One of max or min must be given.") + if xmin is not None: + x = maximum(x, xmin, out=out, where=where, dtype=dtype) + if xmax is not None: + x = minimum(x, xmax, out=out, where=where, dtype=dtype) + return x + + +def deg2rad(x, out=None, where=True, dtype=None): + """ + Converts angles from degrees to radians. + + Args: + x (Tensor): Angles in degrees. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, the corresponding angle in radians. This is a tensor scalar if `x` + is a tensor scalar. + + Raises: + TypeError: if `x` is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> x = np.asarray([1, 2, 3, -4, -5]) + >>> output = np.deg2rad(x) + >>> print(output) + [ 0.01745329 0.03490658 0.05235988 -0.06981317 -0.08726647] + """ + _check_input_tensor(x) + + def convert(a): + return a * pi / 180.0 + return _apply_tensor_op(convert, x, out=out, where=where, dtype=dtype) + + +def rad2deg(x, out=None, where=True, dtype=None): + """ + Converts angles from radians to degrees. + + Args: + x (Tensor): Angles in radians. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. 
+ + Returns: + Tensor, the corresponding angle in degrees. This is a tensor scalar if `x` + is a tensor scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.asarray([1, 2, 3, -4, -5]) + >>> output = np.rad2deg(x) + >>> print(output) + [ 57.295776 114.59155 171.88733 -229.1831 -286.47888 ] + """ + _check_input_tensor(x) + + def convert(a): + return a * 180.0 / pi + return _apply_tensor_op(convert, x, out=out, where=where, dtype=dtype) + + def add(x1, x2, out=None, where=True, dtype=None): """ Adds arguments element-wise. Note: - Numpy arguments casting, order, dtype, subok, signature, and extobj are + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are not supported. - When argument where is provided, argument out must have a tensor value. - Argument out is not supported for storing the result, however it can be - used in combination with argument where to set the value at indices for - which where is set to False. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. Args: x1 (Tensor): input to be added. x2 (Tensor): input to be added. - out (Tensor or None): optional, defaults to None. - where (Tensor or None): optional. For any non-default value of type other - than Tensor or None, the output retains its original value. - This condition is broadcast over the input. At locations where the - condition is True, the out array will be set to the ufunc result. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default out=None, - locations within it where the condition is False will remain + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain uninitialized. - dtype (data type): optional, defaults to None. Overrides the dtype of the + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: - Tensor or scalar, the sum of x1 and x2, element-wise. This is a scalar - if both x1 and x2 are scalars. + Tensor or scalar, the sum of `x1` and `x2`, element-wise. This is a scalar + if both `x1` and `x2` are scalars. Raises: TypeError: if the input is not a tensor. @@ -135,7 +319,8 @@ def add(x1, x2, out=None, where=True, dtype=None): """ # broadcast is not fully supported in tensor_add on CPU, # so we use tensor_sub as a substitute solution - if _get_device_compile() == 'CPU': + if _get_device() == 'CPU': + _check_input_tensor(x1, x2) return subtract(x1, F.neg_tensor(x2), out=out, where=where, dtype=dtype) return _apply_tensor_op(F.tensor_add, x1, x2, out=out, where=where, dtype=dtype) @@ -145,31 +330,30 @@ def subtract(x1, x2, out=None, where=True, dtype=None): Subtracts arguments, element-wise. 
Note: - Numpy arguments casting, order, dtype, subok, signature, and extobj are + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are not supported. - When argument where is provided, argument out must have a tensor value. - Argument out is not supported for storing the result, however it can be - used in combination with argument where to set the value at indices for - which where is set to False. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. Args: x1 (Tensor): the input to be subtracted from. x2 (Tensor): the input to be subtracted by. - out (Tensor or None): optional, defaults to None. - where (Tensor or None): optional. For any non-default value of type other - than Tensor or None, the output retains its original value. - This condition is broadcast over the input. At locations where the - condition is True, the out array will be set to the ufunc result. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default out=None, - locations within it where the condition is False will remain + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain uninitialized. - dtype (data type): optional, defaults to None. Overrides the dtype of the + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: - Tensor or scalar, the difference of x1 and x2, element-wise. This is a - scalar if both x1 and x2 are scalars. + Tensor or scalar, the difference of `x1` and `x2`, element-wise. This is a + scalar if both `x1` and `x2` are scalars. Raises: TypeError: if the input is not a tensor. @@ -194,31 +378,30 @@ def multiply(x1, x2, out=None, where=True, dtype=None): Multiplies arguments element-wise. Note: - Numpy arguments casting, order, dtype, subok, signature, and extobj are + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are not supported. - When argument where is provided, argument out must have a tensor value. - Argument out is not supported for storing the result, however it can be - used in combination with argument where to set the value at indices for - which where is set to False. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. Args: x1 (Tensor): input tensor to be multiplied. x2 (Tensor): input tensor to be multiplied. - out (Tensor or None): optional, defaults to None. - where (Tensor or None): optional. For any non-default value of type other - than Tensor or None, the output retains its original value. - This condition is broadcast over the input. At locations where the - condition is True, the out array will be set to the ufunc result. + out (Tensor or None, optional): defaults to None. 
+ where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. Note that - if an uninitialized out array is created via the default out=None, - locations within it where the condition is False will remain + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain uninitialized. - dtype (data type): optional, defaults to None. Overrides the dtype of the + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. Returns: - Tensor or scalar, the product of x1 and x2, element-wise. This is a scalar - if both x1 and x2 are scalars. + Tensor or scalar, the product of `x1` and `x2`, element-wise. This is a scalar + if both `x1` and `x2` are scalars. Raises: TypeError: if the input is not a tensor. @@ -235,15 +418,13 @@ def multiply(x1, x2, out=None, where=True, dtype=None): [3, 8], [3, 8]] """ - if _get_device_compile() == 'CPU': + if _get_device() == 'CPU': + _check_input_tensor(x1, x2) # broadcast is not fully supported on CPU backend, # and explicit broadcasting is performed shape_out = _infer_out_shape(F.shape(x1), F.shape(x2)) - ndim_out = F.tuple_len(shape_out) - x1 = _expand(x1, ndim_out) - x2 = _expand(x2, ndim_out) - x1 = _broadcast_to(x1, F.shape(x1), shape_out, ndim_out) - x2 = _broadcast_to(x2, F.shape(x2), shape_out, ndim_out) + x1 = _broadcast_to_shape(x1, shape_out) + x2 = _broadcast_to_shape(x2, shape_out) return _apply_tensor_op(F.tensor_mul, x1, x2, out=out, where=where, dtype=dtype) @@ -255,31 +436,29 @@ def divide(x1, x2, out=None, where=True, dtype=None): division. Note: - Numpy arguments casting, order, dtype, subok, signature, and extobj are + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are not supported. - When argument where is provided, argument out must have a tensor value. - Argument out is not supported for storing the result, however it can be - used in combination with argument where to set the value at indices for - which where is set to False. - On GPU, the supported dtypes are np.float16, and np.float32. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. Args: x1 (Tensor): the divident. x2 (Tensor): the divisor. - out (Tensor or None): optional, defaults to None. - where (Tensor or None): optional. For any non-default value of type other - than Tensor or None, the output retains its original value. - This condition is broadcast over the input. At locations where the - condition is True, the out array will be set to the ufunc result. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. Elsewhere, the out array will retain its original value. 
Note that
-            if an uninitialized out array is created via the default out=None,
-            locations within it where the condition is False will remain
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
             uninitialized.
-        dtype (data type): optional, defaults to None. Overrides the dtype of the
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
             output Tensor.
 
     Returns:
-        Tensor or scalar, this is a scalar if both x1 and x2 are scalars.
+        Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.
 
     Raises:
         TypeError: if the input is not a tensor.
@@ -302,35 +481,87 @@ def divide(x1, x2, out=None, where=True, dtype=None):
     return _apply_tensor_op(F.tensor_div, x1, x2, out=out, where=where, dtype=dtype)
 
 
+def true_divide(x1, x2, out=None, where=True, dtype=None):
+    """
+    Returns a true division of the inputs, element-wise.
+
+    Instead of the Python traditional ‘floor division’, this returns a true
+    division.
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+
+    Args:
+        x1 (Tensor): the dividend.
+        x2 (Tensor): the divisor.
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, this is a scalar if both `x1` and `x2` are scalars.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> x1 = np.full((3, 2), [1, 2])
+        >>> x2 = np.full((3, 2), [3, 4])
+        >>> output = np.true_divide(x1, x2)
+        >>> print(output)
+        [[0.33333333, 0.5],
+        [0.33333333, 0.5],
+        [0.33333333, 0.5]]
+    """
+    return divide(x1, x2, out=out, where=where, dtype=dtype)
+
+
 def power(x1, x2, out=None, where=True, dtype=None):
     """
     First array elements raised to powers from second array, element-wise.
 
-    Raises each base in x1 to the positionally-corresponding power in x2.
+    Raises each base in `x1` to the positionally-corresponding power in `x2`.
 
     Note:
-        Numpy arguments casting, order, dtype, subok, signature, and extobj are
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
         not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
         On GPU, the supported dtypes are np.float16, and np.float32.
 
     Args:
         x1 (Tensor): the bases.
         x2 (Tensor): the exponents.
-        out (Tensor or None): optional, defaults to None.
-        where (Tensor or None): optional. For any non-default value of type other
-            than Tensor or None, the output retains its original value.
-            This condition is broadcast over the input.
At locations where the
-            condition is True, the out array will be set to the ufunc result.
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
             Elsewhere, the out array will retain its original value. Note that
-            if an uninitialized out array is created via the default out=None,
-            locations within it where the condition is False will remain
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
             uninitialized.
-        dtype (data type): optional, defaults to None. Overrides the dtype of the
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
             output Tensor.
 
     Returns:
-        Tensor or scalar, the bases in x1 raised to the exponents in x2. This
-        is a scalar if both x1 and x2 are scalars.
+        Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
+        is a scalar if both `x1` and `x2` are scalars.
 
     Raises:
         TypeError: if the input is not a tensor.
@@ -350,7 +581,133 @@ def power(x1, x2, out=None, where=True, dtype=None):
     return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype)
 
 
-def mean(a, axis=None, keepdims=False):
+def float_power(x1, x2, out=None, where=True, dtype=None):
+    """
+    First array elements raised to powers from second array, element-wise.
+
+    Raises each base in `x1` to the positionally-corresponding power in `x2`. `x1` and
+    `x2` must be broadcastable to the same shape. This differs from the power
+    function in that integers, float16, and float64 are promoted to floats with
+    a minimum precision of float32 so that the result is always inexact. The
+    intent is that the function will return a usable result for negative powers
+    and seldom overflow for positive powers.
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+        Integers and floats are promoted to float32 instead of float64.
+
+    Args:
+        x1 (Tensor): the bases.
+        x2 (Tensor): the exponents.
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
+            uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, the bases in `x1` raised to the exponents in `x2`. This
+        is a scalar if both `x1` and `x2` are scalars.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.arange(6) + >>> x2 = np.array(3) + >>> output = np.float_power(x1, x2) + >>> print(output) + [ 0. 1. 8. 27. 64. 125.] + """ + if not _check_same_type(F.dtype(x1), mstype.float32): + x1 = F.cast(x1, mstype.float32) + if not _check_same_type(F.dtype(x2), mstype.float32): + x2 = F.cast(x2, mstype.float32) + + return _apply_tensor_op(F.tensor_pow, x1, x2, out=out, where=where, dtype=dtype) + + +def minimum(x1, x2, out=None, where=True, dtype=None): + """ + Element-wise minimum of tensor elements. + + Compares two tensors and returns a new tensor containing the element-wise minima. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + Unlike numpy, when one of the elements is a NaN, the second element is + always returned regardless of whether the second element is a NaN, instead + of returning NaN. + + Args: + x1 (Tensor): first input tensor to be compared. + x2 (Tensor): second input tensor to be compared. + out (Tensor or None, optional), default is None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor, element-wise minimum of `x1` and `x2`. + + Raises: + TypeError: If inputs have types not specified above. + ValueError: If the shapes of `x1` and `x2` cannot be broadcast. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, 2]) + >>> b = np.asarray([[1, 3],[1, 4]]) + >>> print(np.minimum(a, b)) + [[1 2] + [1 2]] + """ + if isinstance(x1, (int, float, bool, list, tuple, Tensor)) and \ + isinstance(x2, (int, float, bool, list, tuple, Tensor)): + x1 = asarray_const(x1) + x2 = asarray_const(x2) + else: + _raise_type_error("Input x1 and x2 are expected to be array_like") + # if both are scalars, expand x1 to 1d tensor, since cpu kernel doesn't support + # comparisons with 2 scalars + if x1.ndim == 0 and x2.ndim == 0: + x1 = expand_dims(x1, 0) + return _apply_tensor_op(F.minimum, x1, x2, out=out, where=where, dtype=dtype).squeeze() + if x1.ndim == 0: + dtype = x2.dtype + elif x2.ndim == 0: + dtype = x1.dtype + return _apply_tensor_op(F.minimum, x1, x2, out=out, where=where, dtype=dtype) + + +def mean(a, axis=None, keepdims=False, dtype=None): """ Computes the arithmetic mean along the specified axis. @@ -359,29 +716,29 @@ def mean(a, axis=None, keepdims=False): axis. Note: - Numpy arguments dtype and out are not supported. + Numpy arguments `out` is not supported. On GPU, the supported dtypes are np.float16, and np.float32. - On CPU, the supported dtypes are np.float16, np.float32, and - np.float64. Args: a (Tensor): input tensor containing numbers whose mean is desired. 
If a is not an array, a conversion is attempted. - axis (None or int or tuple of ints): optional. Axis or axes along + axis (None or int or tuple of ints, optional): Axis or axes along which the means are computed. The default is to compute the mean of the flattened array. If this is a tuple of ints, a mean is performed over multiple axes. - keepdims(bool): optional. If this is set to True, the axes which + keepdims (bool, optional): If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input tensor. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. Returns: Tensor or scalar, an array containing the mean values. Raises: - ValueError: if axes are out of the range of [-a.ndim, a.ndim), or - if the axes contain duplicates. + ValueError: if axes are out of the range of ``[-a.ndim, a.ndim)``, or + if the axes contain duplicates. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -396,15 +753,17 @@ def mean(a, axis=None, keepdims=False): axis = _check_axis_valid(axis, F.rank(a)) shape_a = F.shape(a) + if dtype is None: + dtype = F.dtype(a) - if _is_empty(shape_a): + if _is_shape_empty(shape_a): if keepdims: shape_out = _shape_reduced_keepdims(shape_a, axis) else: shape_out = _shape_reduced(shape_a, axis) - if _is_empty(shape_out): - return _empty(F.dtype(a), shape_out) - return _full_compile(shape_out, nan) + if _is_shape_empty(shape_out): + return empty(F.dtype(a), shape_out) + return full(shape_out, nan, dtype) if _is_scalar(shape_a): if keepdims: @@ -416,37 +775,38 @@ def mean(a, axis=None, keepdims=False): res = _mean_keepdims(a, axis) else: res = _mean_default(a, axis) + if not _check_same_type(dtype, F.dtype(res)): + res = F.cast(res, dtype) return res def inner(a, b): """ - Inner product of two tensors. + Returns the inner product of two tensors. Ordinary inner product of vectors for 1-D tensors (without complex conjugation), in higher dimensions a sum product over the last axes. Note: - Numpy argument out is not supported. + Numpy argument `out` is not supported. On GPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. Args: - a (Tensor): input tensor. If a and b are nonscalar, their last + a (Tensor): input tensor. If `a` and `b` are nonscalar, their last dimensions must match. - b (Tensor): input tensor. If a and b are nonscalar, their last + b (Tensor): input tensor. If `a` and `b` are nonscalar, their last dimensions must match. Returns: - Tensor or scalar, out.shape = a.shape[:-1] + b.shape[:-1]. + Tensor or scalar. Raises: - ValueError: if x1.shape[-1] != x2.shape[-1]. + ValueError: if ``x1.shape[-1] != x2.shape[-1]``. Supported Platforms: - Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` Examples: @@ -457,27 +817,19 @@ def inner(a, b): >>> print(output) [[[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 3.]] - [[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 3.]] - [[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 3.]] - [[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 3.]] - [[3. 3. 3. 3. 3. 3. 3.] [3. 3. 3. 3. 3. 3. 
3.]]] """ if F.rank(a) == 0 or F.rank(b) == 0: - a = _expand(a, 1) - b = _expand(b, 1) - if F.rank(a) < F.rank(b): - a, b = b, a return F.tensor_mul(a, b) - _ = _check_shape_aligned(F.shape(a), F.shape(b)) + _check_shape_aligned(F.shape(a), F.shape(b)) aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1]) aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1]) a_aligned = F.reshape(a, aligned_shape_a) @@ -488,29 +840,23 @@ def inner(a, b): return res -@constexpr -def _nan(): - """Returns a Tensor with nan value""" - return asarray(float('nan')) - - def dot(a, b): """ - Dot product of two arrays. + Returns the dot product of two arrays. Specifically, - If both a and b are 1-D arrays, it is inner product of vectors + If both `a` and `b` are 1-D arrays, it is inner product of vectors (without complex conjugation). - If both a and b are 2-D arrays, it is matrix multiplication. - If either a or b is 0-D (scalar), it is equivalent to multiply. - If a is an N-D array and b is a 1-D array, it is a sum product - over the last axis of a and b. - If a is an N-D array and b is an M-D array (where M>=2), it is a - sum product over the last axis of a and the second-to-last axis of b: - dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) + If both `a` and `b` are 2-D arrays, it is matrix multiplication. + If either `a` or `b` is 0-D (scalar), it is equivalent to multiply. + If `a` is an `N-D` array and `b` is a 1-D array, it is a sum product + over the last axis of `a` and `b`. + If `a` is an `N-D` array and `b` is an `M-D` array (where ``M>=2``), it is a + sum product over the last axis of `a` and the second-to-last axis of `b`: + ``dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])`` Note: - Numpy argument out is not supported. + Numpy argument `out` is not supported. On GPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. @@ -520,13 +866,13 @@ def dot(a, b): b (Tensor): input tensor Returns: - Tensor or scalar, the dot product of a and b. If a and b are + Tensor or scalar, the dot product of `a` and `b`. If `a` and `b` are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned Raises: - ValueError: If the last dimension of a is not the same size - as the second-to-last dimension of b. + ValueError: If the last dimension of `a` is not the same size + as the second-to-last dimension of `b`. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -552,15 +898,18 @@ def outer(a, b): """ Computes the outer product of two vectors. - Given two vectors, a = [a0, a1, ..., aM] and b = [b0, b1, ..., bN], - the outer product [1] is: - [[a0*b0 a0*b1 ... a0*bN ] - [a1*b0 . - [ ... . - [aM*b0 aM*bN ]] + Given two vectors, ``a = [a0, a1, ..., aM]`` and ``b = [b0, b1, ..., bN]``, + the outer product is: + ``[[a0*b0 a0*b1 ... a0*bN ]`` + + ``[a1*b0 . ]`` + + ``[ ... . ]`` + + ``[aM*b0 aM*bN ]]`` Note: - Numpy argument out is not supported. + Numpy argument ``out`` is not supported. On GPU, the supported dtypes are np.float16, and np.float32. On CPU, the supported dtypes are np.float16, np.float32, and np.float64. @@ -572,7 +921,7 @@ def outer(a, b): already 1-dimensional. Returns: - Tensor or scalar, out[i, j] = a[i] * b[j]. + Tensor or scalar, ``out[i, j] = a[i] * b[j]``. Raises: TypeError: if the input is not a tensor. 
@@ -594,9 +943,7 @@ def outer(a, b): [6, 6, 6, 6], [6, 6, 6, 6]] """ - _check_input_tensor(F.typeof(a)) - _check_input_tensor(F.typeof(b)) - + _check_input_tensor(a, b) if F.rank(a) != 1: a = ravel(a) if F.rank(b) != 1: @@ -610,19 +957,21 @@ def tensordot(a, b, axes=2): """ Computes tensor dot product along specified axes. - Given two tensors, a and b, and an array_like object containing two array_like - objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) - over the axes specified by a_axes and b_axes. The third argument can be a single - non-negative integer_like scalar, N; if it is such, then the last N dimensions of - a and the first N dimensions of b are summed over. + Given two tensors, `a` and `b`, and an array_like object containing two array_like + objects, `(a_axes, b_axes)`, sum the products of `a`’s and `b`’s elements (components) + over the axes specified by `a_axes` and `b_axes`. The third argument can be a single + non-negative integer_like scalar, `N`; if it is such, then the last `N` dimensions of + `a` and the first `N` dimensions of `b` are summed over. Three common use cases are: - axes = 0 : tensor product - axes = 1 : tensor dot product - axes = 2 : (default) tensor double contraction - When axes is integer_like, the sequence for evaluation will be: first the -Nth - axis in a and 0th axis in b, and the -1th axis in a and Nth axis in b last. + ``axes = 0`` : tensor product + + ``axes = 1`` : tensor dot product + + ``axes = 2`` : (default) tensor double contraction + When axes is integer_like, the sequence for evaluation will be: first the `-Nth` + axis in `a` and 0th axis in `b`, and the -1th axis in `a` and `Nth` axis in `b` last. When there is more than one axis to sum over - and they are not the last (first) - axes of a (b) - the argument axes should consist of two sequences of the same + axes of `a` `(b)` - the argument axes should consist of two sequences of the same length, with the first axis to sum over given first in both sequences, the second axis second, and so forth. The shape of the result consists of the non-contracted axes of the first tensor, @@ -633,20 +982,20 @@ def tensordot(a, b, axes=2): On GPU, the supported dypes are np.float16 and np.float32. Args: - a, b (Tensor): Tensors to “dot”. - axes (int or (2,) array_like): - integer_like: If an int N, sum over the last N axes of a and the first N - axes of b in order. The sizes of the corresponding axes must match. - (2,) array_like: Or, a list of axes to be summed over, first sequence - applying to a, second to b. Both elements array_like must be of the same + a (Tensor): Tensor to "dot". + b (Tensor): Tensor to “dot”. + axes (int or sequence of ints): + + integer_like: If an int `N`, sum over the last `N` axes of `a` and the first `N` + axes of `b` in order. The sizes of the corresponding axes must match. + + sequence of ints: Or, a list of axes to be summed over, first sequence + applying to `a`, second to `b`. Both elements `array_like` must be of the same length. Returns: Tensor, or list of tensors, the tensor dot product of the input. - Raises: - TypeError: if the input is not a tensor. 
- Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -657,41 +1006,1232 @@ def tensordot(a, b, axes=2): >>> print(output.shape) (5, 2) """ - _check_input_tensor(F.typeof(a)) - _check_input_tensor(F.typeof(b)) - if F.rank(a)*F.rank(b) == 0 and axes == 0: return F.tensor_mul(a, b) return C.tensor_dot(a, b, axes) -@constexpr -def _full_compile(shape, value): - return full(shape, value) +def std(x, axis=None, ddof=0, keepdims=False): + """ + Computes the standard deviation along the specified axis. + The standard deviation is the square root of the average of the squared deviations + from the mean, i.e., :math:`std = sqrt(mean(abs(x - x.mean())**2))`. + Returns the standard deviation, which is computed for the flattened array by default, + otherwise over the specified axis. -@constexpr -def _shape_reduced_keepdims(shape, axes): - """ - Reduces dimensions corresponding to argument axes while - keeping the number of dimensions unchanged. - """ - ndim_out = F.tuple_len(shape) - shape_out = [1]*ndim_out - for i in range(ndim_out): - if not i in axes: - shape_out[i] = shape[i] - return tuple(shape_out) + Note: + Numpy arguments `dtype` and `out` are not supported. + Args: + x (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the standard + deviation is computed. Default: `None`. -@constexpr -def _shape_reduced(shape, axes): - """Removes dimensions corresponding to argument axes""" - ndim_orig = F.tuple_len(shape) - ndim_out = ndim_orig - F.tuple_len(axes) - shape_out = [0]*ndim_out - idx_out = 0 - for i in range(ndim_orig): + If `None`, compute the standard deviation of the flattened array. + ddof (int): Means Delta Degrees of Freedom. The divisor used in calculations is :math:`N - ddof`, + where :math:`N` represents the number of elements. Default: 0. + keepdims: Default: `False`. + + Returns: + Standard deviation tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([1., 2., 3., 4.]) + >>> output = np.std(input_x) + >>> print(output) + 1.118034 + """ + if _is_shape_empty(x.shape): + return full((), nan, F.dtype(x)) + + if not isinstance(ddof, int): + _raise_type_error("integer argument expected, but got ", ddof) + if axis is None: + axis = () + else: + _check_axis_type(axis, True, True, False) + axis = _canonicalize_axis(axis, x.ndim) + + x_mean = _mean_keepdims(x, axis) + x_sub = F.tensor_sub(x, x_mean) + x_pow = F.tensor_pow(x_sub, 2) + if keepdims: + x_sum = _reduce_sum_keepdims(x_pow, axis) + else: + x_sum = _reduce_sum_default(x_pow, axis) + + if isinstance(axis, int): + nums = x.shape[axis] + else: + nums = _get_size(x, axis) + + x_std = F.tensor_pow(F.tensor_div(x_sum, nums - ddof), 0.5) + return x_std + + +def var(x, axis=None, ddof=0, keepdims=False): + """ + Computes the variance along the specified axis. + The variance is the average of the squared deviations from the mean, i.e., + :math:`var = mean(abs(x - x.mean())**2)`. + + Returns the variance, which is computed for the flattened array by default, + otherwise over the specified axis. + + Note: + Numpy arguments `dtype` and `out` are not supported. + + Args: + x (Tensor): A Tensor to be calculated. + axis (Union[None, int, tuple(int)]): Axis or axes along which the variance is computed. + The default is to compute the variance of the flattened array. Default: `None`. + ddof (int): Means Delta Degrees of Freedom. Default: 0. 
+            The divisor used in calculations is :math:`N - ddof`, where :math:`N` represents the number of elements.
+        keepdims (bool): Default: `False`.
+
+    Returns:
+        Variance tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.numpy as np
+        >>> input_x = np.array([1., 2., 3., 4.])
+        >>> output = np.var(input_x)
+        >>> print(output)
+        1.25
+    """
+    if _is_shape_empty(x.shape):
+        return full((), nan, F.dtype(x))
+
+    x_std = std(x, axis, ddof, keepdims)
+    return F.tensor_pow(x_std, 2)
+
+
+def ptp(x, axis=None, out=None, keepdims=False):
+    """
+    Range of values (maximum - minimum) along an axis.
+    The name of the function comes from the acronym for ‘peak to peak’.
+
+    Note:
+        Numpy argument `out` is not supported.
+
+    Args:
+        x (Tensor): Input tensor.
+        axis (Union[None, int, tuple(int)]): Axis or axes along which the range is computed.
+            The default is to compute the range of the flattened array. Default: None.
+        out (Tensor or None, optional): defaults to None, not supported.
+        keepdims (bool): Default is False.
+
+    Returns:
+        Tensor.
+
+    Raises:
+        TypeError: if inputs have types not specified above.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore.numpy as np
+        >>> x = np.array([[4.0, 9.0, 2.0, 10.0], [6.0, 9.0, 7.0, 12.0]])
+        >>> print(np.ptp(x, axis=1))
+        [8. 6.]
+        >>> print(np.ptp(x, axis=0))
+        [2. 0. 5. 2.]
+    """
+    _check_input_tensor(x)
+    if axis is None:
+        axis = ()
+    else:
+        _check_axis_type(axis, True, True, False)
+        axis = _canonicalize_axis(axis, x.ndim)
+
+    if keepdims:
+        x_min = _reduce_min_keepdims(x, axis)
+        x_max = _reduce_max_keepdims(x, axis)
+    else:
+        x_min = _reduce_min_default(x, axis)
+        x_max = _reduce_max_default(x, axis)
+    return F.tensor_sub(x_max, x_min)
+
+
+def average(x, axis=None, weights=None, returned=False):
+    """
+    Computes the weighted average along the specified axis.
+
+    Args:
+        x (Tensor): A Tensor to be averaged.
+        axis (Union[None, int, tuple(int)]): Axis along which to average `x`. Default: `None`.
+            If the axis is `None`, it will average over all of the elements of the tensor `x`.
+            If the axis is negative, it counts from the last to the first axis.
+        weights (Tensor): Weights associated with the values in `x`. Default: `None`.
+            If `weights` is `None`, all the data in `x` are assumed to have a weight equal to one.
+            If `weights` is 1-D tensor, the length must be the same as the given axis.
+            Otherwise, `weights` should have the same shape as `x`.
+        returned (bool): Default: `False`.
+            If `True`, the tuple (average, sum_of_weights) is returned.
+            If `False`, only the average is returned.
+
+    Returns:
+        Averaged Tensor. If `returned` is `True`, returns the tuple (average, sum_of_weights). 
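As a quick cross-check of the reductions above (an illustrative sketch, not part of the patch, assuming NumPy-compatible semantics): `var` should be the square of `std`, `ptp` is the max-min range, and `average` with uniform weights should agree with `mean`:

import mindspore.numpy as np

x = np.array([[1., 2.], [3., 4.]])
print(np.var(x))                              # expected: 1.25, i.e. np.std(x) ** 2
print(np.ptp(x, axis=0))                      # expected: [2. 2.]
w = np.ones((2, 2)).astype('float32')
print(np.average(x, axis=0, weights=w))       # expected: [2. 3.], same as np.mean(x, axis=0)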
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> input_x = np.array([[1., 2.], [3., 4.]]) + >>> output = np.average(input_x, axis=0, weights=input_x, returned=True) + >>> print(output) + (Tensor(shape=[2], dtype=Float32, value= [ 2.50000000e+00, 3.33333325e+00]), + Tensor(shape=[2], dtype=Float32, value= [ 4.00000000e+00, 6.00000000e+00])) + """ + if axis is None: + axis = () + else: + _check_axis_type(axis, True, True, False) + axis = _canonicalize_axis(axis, x.ndim) + + if weights is None: + return mean(x, axis) + + x_avg = full((), nan, F.dtype(x)) + sum_of_weights = None + if x.shape == weights.shape: + x_avg, sum_of_weights = comput_avg(x, axis, weights) + elif F.rank(weights) == 1: + if not isinstance(axis, int): + _raise_type_error("Axis must be specified when shapes of x and weights differ.") + weights = _broadcast_to_shape(weights, x.shape) + x_avg, sum_of_weights = comput_avg(x, axis, weights) + else: + _raise_type_error("Weights should be None, 1-D or the same as input x, but got shape of", weights) + + if returned: + return (x_avg, sum_of_weights) + return x_avg + + +def comput_avg(x, axis, weights): + """Computes average value of input x with given parameters.""" + x_mul = F.tensor_mul(x, weights) + x_sum = _reduce_sum_default(x_mul, axis) + sum_of_weights = _reduce_sum_default(weights, axis) + x_avg = F.tensor_div(x_sum, sum_of_weights) + return x_avg, sum_of_weights + + +def matmul(x1, x2, dtype=None): + """ + Returns the matrix product of two arrays. + + Note: + Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + On GPU, the supported dtypes are np.float16 and np.float32. + On CPU, the supported dtypes are np.float16 and np.float32. + + Args: + x1 (Tensor): Input tensor, scalar not allowed. + x2 (Tensor): Input tensor, scalar not allowed. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the matrix product of the inputs. This is a scalar only + when both `x1`, `x2` are 1-d vectors. + + Raises: + ValueError: If the last dimension of `x1` is not the same size as the + second-to-last dimension of `x2`, or if a scalar value is passed in. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32') + >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32') + >>> output = np.matmul(x1, x2) + >>> print(output) + [[[ 70. 76. 82. 88. 94.] + [ 190. 212. 234. 256. 278.] + [ 310. 348. 386. 424. 462.]] + [[ 430. 484. 538. 592. 646.] + [ 550. 620. 690. 760. 830.] + [ 670. 756. 842. 928. 
1014.]]] + """ + # performs type promotion + dtype1 = F.dtype(x1) + dtype2 = F.dtype(x2) + dtype_out = _promote(dtype1, dtype2) + if not _check_same_type(dtype1, dtype_out): + x1 = F.cast(x1, dtype_out) + if not _check_same_type(dtype2, dtype_out): + x2 = F.cast(x2, dtype_out) + + ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2) + shape1_orig, shape2_orig = F.shape(x1), F.shape(x2) + _check_matmul_shapes(shape1_orig, shape2_orig) + ndim_aligned = _max(ndim1_orig, ndim2_orig) + transpose_b = ndim2_orig == 1 + shape_backbone = _infer_out_shape( + shape1_orig[:-2], shape2_orig[:-2]) + # infers the shape of the output + shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig, + ndim1_orig, ndim2_orig, transpose_b) + + x1 = _expand(x1, _max(ndim_aligned, 2)) + x2 = _expand(x2, _max(ndim_aligned, 2)) + shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2) + + if ndim_aligned <= 2: + res = P.MatMul(False, transpose_b)(x1, x2) + else: + # broadcasts x1.shape[:-2] with x2.shape[:-2] + shape_aligned = shape_backbone + _infer_shape_rem(shape1_aligned, shape2_aligned, + ndim_aligned, ndim_aligned, + transpose_b) + x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_aligned[:-2], ndim_aligned) + x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_aligned[:-2], ndim_aligned) + res = P.BatchMatMul(False, transpose_b)(x1, x2) + + if dtype is not None and not _check_same_type(dtype_out, dtype): + res = F.cast(res, dtype) + return F.reshape(res, shape_out) + + +def square(x, out=None, where=True, dtype=None): + """ + Returns the element-wise square of the input. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16 and np.float32. + + Args: + x (Tensor): Input data. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise ``x*x``, of the same shape and dtype as `x`. + This is a scalar if `x` is a scalar.. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.square(np.arange(6).reshape(2, 3).astype('float32')) + >>> print(x) + [[ 0. 1. 4.] + [ 9. 16. 25.]] + """ + return _apply_tensor_op(F.square, x, out=out, where=where, dtype=dtype) + + +def sqrt(x, out=None, where=True, dtype=None): + """ + Returns the non-negative square-root of an array, element-wise. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. 
`out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16 and np.float32. + + Args: + x (Tensor): The values whose square-roots are required. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, an array of the same shape as `x`, containing the positive + square-root of each element in `x`. For negative elements, nan is returned. + This is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.arange(6).reshape(2, 3).astype('float32') + >>> x_squared = np.square(x) + >>> output = np.sqrt(x_squared) + >>> print(output) + [[ 0. 1. 2.] + [ 3. 4. 5.]] + """ + return _apply_tensor_op(F.sqrt, x, out=out, where=where, dtype=dtype) + + +def reciprocal(x, out=None, where=True, dtype=None): + """ + Returns the reciprocal of the argument, element-wise. + + Calculates ``1/x``. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x (Tensor): Input array. For integer arguments with absolute value larger + than 1 the result is always zero because of the way Python handles + integer division. For integer zero the result is an overflow. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, this is a scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.arange(1, 7).reshape(2, 3).astype('float32') + >>> output = np.reciprocal(x) + >>> print(output) + [[1. 0.5 0.33333334] + [0.25 0.2 0.16666667]] + """ + return _apply_tensor_op(lambda x: F.tensor_div(1, x), x, out=out, where=where, dtype=dtype) + + +def log(x, out=None, where=True, dtype=None): + """ + Returns the natural logarithm, element-wise. 
+ + The natural logarithm log is the inverse of the exponential function, so that + ``log(exp(x)) = x``. The natural logarithm is logarithm in base e. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. + + Args: + x (Tensor): Input array. For integer arguments with absolute value larger + than 1 the result is always zero because of the way Python handles + integer division. For integer zero the result is an overflow. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the natural logarithm of `x`, element-wise. This is a + scalar if `x` is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> x = np.array([1, 2, 3]).astype('float32') + >>> output = np.log(x) + >>> print(output) + [1.09861 1.3862929 1.6094407] + """ + return _apply_tensor_op(F.log, x, out=out, where=where, dtype=dtype) + + +def maximum(x1, x2, out=None, where=True, dtype=None): + """ + Returns the element-wise maximum of array elements. + + Compares two arrays and returns a new array containing the element-wise maxima. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + Unlike numpy, when one of the elements is a NaN, the second element is + always returned regardless of whether the second element is a NaN, instead + of returning NaN. + + Args: + x1 (Tensor): Input array + x2 (Tensor): The array holding the elements to be compared. If + ``x1.shape != x2.shape``, they must be broadcastable to a common shape + (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. 
Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the maximum of `x1` and `x2`, element-wise. This is a scalar + if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.maximum(np.array([2, 3, 4]), np.array([1, 5, 2])) + >>> print(output) + [2 5 4] + """ + if isinstance(x1, (int, float, bool, list, tuple, Tensor)) and \ + isinstance(x2, (int, float, bool, list, tuple, Tensor)): + x1 = asarray_const(x1) + x2 = asarray_const(x2) + else: + _raise_type_error("Input x1 and x2 are expected to be array_like") + # F.maximum does not support when both operands are scalar + if x1.ndim == 0 and x2.ndim == 0: + x1 = expand_dims(x1, 0) + return _apply_tensor_op(F.maximum, x1, x2, out=out, where=where, dtype=dtype).squeeze() + if x1.ndim == 0: + dtype = x2.dtype + elif x2.ndim == 0: + dtype = x1.dtype + return _apply_tensor_op(F.maximum, x1, x2, out=out, where=where, dtype=dtype) + + +def heaviside(x1, x2, out=None, where=True, dtype=None): + """ + Computes the Heaviside step function. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): Input values. + x2 (Tensor): The value of the function when `x1` is 0. If + ``x1.shape != x2.shape``, they must be broadcastable to a common shape + (which becomes the shape of the output). + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the output array, element-wise Heaviside step function + of `x1`. This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.heaviside(np.array([-1.5, 0, 2.0]), np.array(0.5)) + >>> print(output) + [0. 0.5 1. ] + >>> output = np.heaviside(np.array([-1.5, 0, 2.0]), np.array(1)) + >>> print(output) + [0. 1. 1.] 
+ """ + + def _heaviside(x1, x2): + """Computes heaviside without passing keyword arguments""" + # performs type promotion + dtype1 = F.dtype(x1) + dtype2 = F.dtype(x2) + dtype_out = _promote(dtype1, dtype2) + if not _check_same_type(dtype1, dtype_out): + x1 = F.cast(x1, dtype_out) + if not _check_same_type(dtype2, dtype_out): + x2 = F.cast(x2, dtype_out) + + # performs broadcast + shape_out = _infer_out_shape(F.shape(x1), F.shape(x2)) + x1 = _broadcast_to_shape(x1, shape_out) + x2 = _broadcast_to_shape(x2, shape_out) + + x2 = F.select(x1 < 0, zeros(shape_out, dtype_out), x2) + x2 = F.select(x1 > 0, ones(shape_out, dtype_out), x2) + return x2 + + return _apply_tensor_op(_heaviside, x1, x2, out=out, where=where, dtype=dtype) + + +def amax(a, axis=None, keepdims=False, initial=None, where=True): + """ + Returns the maximum of an array or maximum along an axis. + + Note: + Numpy argument `out` is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + a (Tensor): Input data. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the maximum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The minimum value of an output element. Must be present to allow + computation on empty slice. + where (boolean Tensor, optional): defaults to True. + A boolean array which is broadcasted to match the dimensions of array, + and selects elements to include in the reduction. If non-default value + is passed, initial must also be provided. + + Returns: + Tensor or scalar, maximum of `a`. If `axis` is None, the result is a scalar + value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> a = np.arange(4).reshape((2,2)).astype('float32') + >>> output = np.amax(a) + >>> print(output) + 3.0 + >>> output = np.amax(a, axis=0) + >>> print(output) + [2. 3.] + >>> output = np.amax(a, axis=1) + >>> print(output) + [1. 3.] + >>> output = np.amax(a, where=np.array([False, True]), initial=-1, axis=0) + >>> print(output) + [-1. 3.] + """ + return _reduce(a, P.ReduceMax(keepdims), F.maximum, axis=axis, keepdims=keepdims, + initial=initial, where=where) + + +def amin(a, axis=None, keepdims=False, initial=None, where=True): + """ + Returns the minimum of an array or minimum along an axis. + + Note: + Numpy argument `out` is not supported. + On GPU, the supported dtypes are np.float16, and np.float32. + + Args: + a (Tensor): Input data. + axis (None or int or tuple of ints, optional): defaults to None. Axis or + axes along which to operate. By default, flattened input is used. If + this is a tuple of ints, the maximum is selected over multiple axes, + instead of a single axis or all the axes as before. + keepdims (boolean, optional): defaults to False. + If this is set to True, the axes which are reduced are left in the + result as dimensions with size one. With this option, the result will + broadcast correctly against the input array. + initial (scalar, optional): + The maximum value of an output element. 
Must be present to allow
+            computation on empty slice.
+        where (boolean Tensor, optional): defaults to True.
+            A boolean array which is broadcasted to match the dimensions of array,
+            and selects elements to include in the reduction. If non-default value
+            is passed, initial must also be provided.
+
+    Returns:
+        Tensor or scalar, minimum of `a`. If axis is None, the result is a scalar
+        value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> a = np.arange(4).reshape((2,2)).astype('float32')
+        >>> output = np.amin(a)
+        >>> print(output)
+        0.0
+        >>> output = np.amin(a, axis=0)
+        >>> print(output)
+        [0. 1.]
+        >>> output = np.amin(a, axis=1)
+        >>> print(output)
+        [0. 2.]
+        >>> output = np.amin(a, where=np.array([False, True]), initial=10, axis=0)
+        >>> print(output)
+        [10. 1.]
+    """
+    return _reduce(a, P.ReduceMin(keepdims), F.minimum, axis=axis, keepdims=keepdims,
+                   initial=initial, where=where)
+
+
+def hypot(x1, x2, out=None, where=True, dtype=None):
+    """
+    Given the “legs” of a right triangle, returns its hypotenuse.
+
+    Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or `x2` is scalar_like
+    (i.e., unambiguously cast-able to a scalar type), it is broadcast for use
+    with each element of the other argument. (See Examples)
+
+    Note:
+        Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are
+        not supported.
+        When `where` is provided, `out` must have a tensor value. `out` is not supported
+        for storing the result, however it can be used in combination with `where` to set
+        the value at indices for which `where` is set to False.
+        On GPU, the supported dtypes are np.float16 and np.float32.
+        On CPU, the supported dtypes are np.float16, np.float32, and np.float64.
+
+    Args:
+        x1 (Tensor): Leg of the triangle(s).
+        x2 (Tensor): Leg of the triangle(s). If ``x1.shape != x2.shape``, they
+            must be broadcastable to a common shape (which becomes the shape of
+            the output).
+        out (Tensor or None, optional): defaults to None.
+        where (Tensor or None, optional): For any non-default value of type other
+            than :class:`Tensor` or :class:`None`, the output retains its original value.
+            This condition is broadcasted over the input. At locations where the
+            condition is `True`, the out array will be set to the ufunc result.
+            Elsewhere, the out array will retain its original value. Note that
+            if an uninitialized out array is created via the default ``out=None``,
+            locations within it where the condition is `False` will remain
+            uninitialized.
+        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
+            output Tensor.
+
+    Returns:
+        Tensor or scalar, the hypotenuse of the triangle(s). This is a scalar if
+        both `x1` and `x2` are scalars.
+
+    Raises:
+        TypeError: if the input is not a tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> output = np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3)))
+        >>> print(output)
+        [[5. 5. 5.]
+        [5. 5. 5.]
+        [5. 5. 5.]]
+        >>> output = np.hypot(3*np.ones((3, 3)), np.array([4]))
+        >>> print(output)
+        [[5. 5. 5.]
+        [5. 5. 5.]
+        [5. 5. 
5.]] + """ + + def _hypot(x1, x2): + """Computes hypotenuse without passing keyword arguments""" + if _get_device() == 'CPU': + # broadcast is not fully supported in tensor_add on CPU, + # so we use tensor_sub as a substitute solution + return F.sqrt(F.tensor_sub(F.square(x1), F.neg_tensor(F.square(x2)))) + return F.sqrt(F.tensor_add(F.square(x1), F.square(x2))) + + return _apply_tensor_op(_hypot, x1, x2, out=out, where=where, dtype=dtype) + + +def floor(x, out=None, where=True, dtype=None): + """ + Returns the floor of the input, element-wise. + + The floor of the scalar `x` is the largest integer `i`, such that ``i <= x``. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16 and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, and np.float64. + + Args: + x (Tensor): input data. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the floor of each element in `x`. This is a scalar if `x` + is a scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.floor(np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])) + >>> print(output) + [-2. -2. -1. 0. 1. 1. 2.] + """ + return _apply_tensor_op(F.floor, x, out=out, where=where, dtype=dtype) + + +def floor_divide(x1, x2, out=None, where=True, dtype=None): + """ + Returns the largest integer smaller or equal to the division of the inputs. + It is equivalent to the Python // operator and pairs with the + Python % (remainder), function so that ``a = a % b + b * (a // b)`` up to roundoff. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): Input array. + x2 (Tensor): Input array. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. 
+ dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.floor_divide(np.array([1., 2., 3., 4.]), np.array(2.5)) + >>> print(output) + [0. 0. 1. 1.] + """ + return _apply_tensor_op(F.tensor_floordiv, x1, x2, out=out, where=where, dtype=dtype) + + +def _remainder(x1, x2, C_style=False): + """Computes remainder without applying keyword arguments.""" + dtype = _promote(F.dtype(x1), F.dtype(x2)) + if not _check_is_float(dtype): + x1 = F.cast(x1, mstype.float32) + x2 = F.cast(x2, mstype.float32) + + quotient = F.tensor_div(x1, x2) + if C_style: + quotient = fix(quotient) + else: + quotient = F.floor(quotient) + prod = F.tensor_mul(x2, quotient) + res = F.tensor_sub(x1, prod) + if _check_is_int(dtype): + zeros_tensor = zeros(F.shape(quotient), F.dtype(quotient)) + x2_zeros = F.equal(x2, zeros_tensor) + res = F.select(x2_zeros, zeros_tensor, res) + + if not _check_same_type(F.dtype(res), dtype): + res = F.cast(res, dtype) + return res + + +def remainder(x1, x2, out=None, where=True, dtype=None): + """ + Returns element-wise remainder of division. + + Computes the remainder complementary to the floor_divide function. It is + equivalent to the Python modulus operator ``x1 % x2`` and has the same sign + as the divisor `x2`. The MATLAB function equivalent to np.remainder is mod. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor): input array. + x2 (Tensor): input array. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the element-wise remainder of the quotient + ``floor_divide(x1, x2)``. This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.remainder(np.array([4, 7]), np.array([2, 3])) + >>> print(output) + [0 1] + >>> output = np.remainder(np.arange(7), np.array(5)) + >>> print(output) + [0 1 2 3 4 0 1] + """ + return _apply_tensor_op(_remainder, x1, x2, out=out, where=where, dtype=dtype) + + +def fix(x): + """ + Rounds to nearest integer towards zero. + + Rounds an array of floats element-wise to nearest integer towards zero. The + rounded values are returned as floats. + + Note: + Numpy argument `out` is not supported. + + Args: + x (Tensor): An array of floats to be rounded. + + Returns: + Tensor. + + Raises: + TypeError: if the input is not a tensor. 
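The relation stated above between `floor_divide` and `remainder`, and the differing sign conventions of `remainder` and `fmod`, can be illustrated with a small sketch (not part of the patch, assuming NumPy-compatible behavior of the new interfaces):

import mindspore.numpy as np

a = np.array([5., -5., 7.])
b = np.array([3., 3., -2.])
# a == remainder(a, b) + b * floor_divide(a, b), up to roundoff
print(np.add(np.remainder(a, b), np.multiply(b, np.floor_divide(a, b))))   # expected: [ 5. -5.  7.]
print(np.remainder(a, b))   # sign follows the divisor:  [ 2.  1. -1.]
print(np.fmod(a, b))        # sign follows the dividend: [ 2. -2.  1.]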
+ + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.fix(np.array([2.1, 2.9, -2.1, -2.9])) + >>> print(output) + [ 2. 2. -2. -2.] + """ + _check_input_tensor(x) + if not _check_is_float(F.dtype(x)): + x = F.cast(x, mstype.float32) + floored = F.floor(x) + # TODO change to F.ceil once supported on CPU. + ceiled = F.neg_tensor(F.floor(F.neg_tensor(x))) + is_neg = F.tensor_lt(x, zeros(F.shape(x), F.dtype(x))) + return F.select(is_neg, ceiled, floored) + + +def fmod(x1, x2, out=None, where=True, dtype=None): + """ + Returns the element-wise remainder of division. + + This is the NumPy implementation of the C library function fmod, the remainder + has the same sign as the dividend `x1`. It is equivalent to the Matlab(TM) rem + function and should not be confused with the Python modulus operator ``x1 % x2``. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x1 (Tensor) + x2 (Tensor): input arrays. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the remainder of the division of `x1` by `x2`. This is a + scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.fmod(np.array([-3, -2, -1, 1, 2, 3]), np.array(2)) + >>> print(output) + [-1 0 -1 1 0 1] + """ + return _apply_tensor_op(lambda x1, x2: _remainder(x1, x2, C_style=True), x1, x2, + out=out, where=where, dtype=dtype) + + +def trunc(x, out=None, where=True, dtype=None): + """ + Returns the element-wise remainder of division. + + This is the NumPy implementation of the C library function fmod, the remainder + has the same sign as the dividend `x1`. It is equivalent to the Matlab(TM) rem + function and should not be confused with the Python modulus operator ``x1 % x2``. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + + Args: + x (Tensor): input data. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. 
Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, the remainder of the division of `x1` by `x2`. This is a + scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.trunc(np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])) + >>> print(output) + [-1. -1. -0. 0. 1. 1. 2.] + """ + return _apply_tensor_op(fix, x, out=out, where=where, dtype=dtype) + + +def exp(x, out=None, where=True, dtype=None): + """ + Calculates the exponential of all elements in the input array. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, np.float32, np.float64. + + Args: + x (Tensor): input data. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise exponential of `x`. This is a scalar if both + `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.exp(np.arange(5).astype(np.float32)) + >>> print(output) + [ 1. 2.718282 7.3890557 20.085537 54.598145 ] + """ + return _apply_tensor_op(F.tensor_exp, x, out=out, where=where, dtype=dtype) + + +def expm1(x, out=None, where=True, dtype=None): + """ + Calculates ``exp(x) - 1`` for all elements in the array. + + Note: + Numpy arguments `casting`, `order`, `dtype`, `subok`, `signature`, and `extobj` are + not supported. + When `where` is provided, `out` must have a tensor value. `out` is not supported + for storing the result, however it can be used in combination with `where` to set + the value at indices for which `where` is set to False. + On GPU, the supported dtypes are np.float16, and np.float32. + On CPU, the supported dtypes are np.float16, and np.float32. + + Args: + x (Tensor): input data. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. 
Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor or scalar, element-wise exponential minus one, ``out = exp(x) - 1``. + This is a scalar if both `x1` and `x2` are scalars. + + Raises: + TypeError: if the input is not a tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> output = np.expm1(np.arange(5).astype(np.float32)) + >>> print(output) + [ 0. 1.7182819 6.389056 19.085537 53.59815 ] + """ + return _apply_tensor_op(F.tensor_expm1, x, out=out, where=where, dtype=dtype) + + +@constexpr +def _real_axes(ndim_orig, ndim_out, axes_orig): + """Returns the real axes to be reduced after performing broadcast""" + diff = ndim_out - ndim_orig + axes = F.make_range(diff) + axes_orig = map(functools.partial(operator.add, diff), axes_orig) + return axes + tuple(axes_orig) + + +@constexpr +def _shape_reduced_keepdims(shape, axes): + """ + Reduces dimensions corresponding to argument axes while + keeping the number of dimensions unchanged. + """ + ndim_out = F.tuple_len(shape) + shape_out = [1]*ndim_out + for i in range(ndim_out): + if not i in axes: + shape_out[i] = shape[i] + return tuple(shape_out) + + +@constexpr +def _shape_reduced(shape, axes): + """Removes dimensions corresponding to argument axes""" + ndim_orig = F.tuple_len(shape) + ndim_out = ndim_orig - F.tuple_len(axes) + shape_out = [0]*ndim_out + idx_out = 0 + for i in range(ndim_orig): if not i in axes: shape_out[idx_out] = shape[i] idx_out += 1 @@ -712,22 +2252,137 @@ def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b): return shape_rem +def _reduce(a, reduce_fn, cmp_fn, axis=None, keepdims=False, initial=None, where=True): + """Applies comparison based on cmp_fn and reduction based on reduce_fn""" + _check_input_tensor(a) + + shape = F.shape(a) + ndim = F.rank(a) + dtype = F.dtype(a) + axes = _check_axis_valid(axis, ndim) + + if _is_shape_empty(shape): + if not axes: + return a + if keepdims: + shape_out = _shape_reduced_keepdims(shape, axes) + else: + shape_out = _shape_reduced(shape, axes) + if _is_shape_empty(shape_out): + return empty(F.dtype(a), shape_out) + if initial is None: + return _raise_value_error('initial value must be provided for zero-size arrays') + return full(shape_out, initial, dtype) + + if initial is not None: + initial = full(shape, initial, dtype) + a = cmp_fn(a, initial) + if not axes: + return a + if isinstance(where, Tensor): + if initial is None: + return _raise_value_error('initial value must be provided for where masks') + ndim_orig = F.rank(a) + a = where_(where, a, initial) + axes = _real_axes(ndim_orig, F.rank(a), axes) + + return reduce_fn(a, axes) + + +def positive(a, out=None, where=True, dtype=None): + """ + Numerical positive, element-wise. + + Note: + Numpy arguments casting, order, subok, signature, and extobj are + not supported. + + Args: + a (Tensor): Input tensor. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. 
Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, -1]) + >>> output = np.positive(a) + >>> print(output) + [1, -1] + """ + _check_input_tensor(a) + neg_tensor = F.neg_tensor(a) + return _apply_tensor_op(F.neg_tensor, neg_tensor, out=out, where=where, dtype=dtype) + + +def negative(a, out=None, where=True, dtype=None): + """ + Numerical negative, element-wise. + + Note: + Numpy arguments `casting`, `order`, `subok`, `signature`, and `extobj` are + not supported. + + Args: + a (Tensor): Input tensor. + out (Tensor or None, optional): defaults to None. + where (Tensor or None, optional): For any non-default value of type other + than :class:`Tensor` or :class:`None`, the output retains its original value. + This condition is broadcasted over the input. At locations where the + condition is `True`, the out array will be set to the ufunc result. + Elsewhere, the out array will retain its original value. Note that + if an uninitialized out array is created via the default ``out=None``, + locations within it where the condition is `False` will remain + uninitialized. + dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the + output Tensor. + + Returns: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.numpy as np + >>> a = np.asarray([1, -1]) + >>> output = np.negative(a) + >>> print(output) + [-1, 1] + """ + _check_input_tensor(a) + return _apply_tensor_op(F.neg_tensor, a, out=out, where=where, dtype=dtype) + + def _apply_tensor_op(fn, *args, out=None, where=True, dtype=None): - """applies tensor operations based on fn""" - for arg in args: - _check_input_tensor(F.typeof(arg)) + """Applies tensor operations based on fn""" + _check_input_tensor(*args) res = fn(*args) # if out is set to a non-default value, return tensor will have the same # dtype as out, which overrides the dtype passed into the keyword argument - if _check_is_tensor(F.typeof(out)): + if isinstance(out, Tensor): dtype_out = F.dtype(out) elif dtype is not None: dtype_out = dtype else: dtype_out = F.dtype(res) - if _check_is_tensor(F.typeof(out)) and _check_is_tensor(F.typeof(where)): + if isinstance(out, Tensor) and isinstance(where, Tensor): out = where_(where, res, out) elif out is None or where is not None: out = res diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py index 54e78a485e..eed0c7f492 100644 --- a/mindspore/numpy/utils.py +++ b/mindspore/numpy/utils.py @@ -16,11 +16,11 @@ import numpy as onp -import mindspore.context as context from ..common import Tensor from ..ops import functional as F +from ..common import dtype as mstype -from .utils_const import _tile_size +from .utils_const import _tile_size, _add_unit_axes, _raise_type_error def _deep_list(array_like): @@ -56,10 +56,9 @@ def _deep_tensor_to_nparray(array_like): def _check_input_for_asarray(array_like): """check whether array_like argument is a valid type for np.asarray conversion""" - if isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): - return True - raise TypeError("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ - f"or numpy.ndarray, but 
got {type(array_like)}") + if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)): + _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple`" + \ + "or numpy.ndarray, but got ", array_like) def _is_scalar(shape): @@ -67,16 +66,6 @@ def _is_scalar(shape): return F.shape_mul(shape) == 1 -def _is_empty(shape): - """Checks if the shape is empty""" - return F.shape_mul(shape) == 0 - - -def _get_device(): - """Get the current device (`GPU`, `CPU`, `Ascend`)""" - return context.get_context('device_target') - - def _convert_list_tensor_to_tuple_tensor(list_of_tensor): """Convert a list of tensor to a tuple of tensor""" if isinstance(list_of_tensor, list): @@ -87,19 +76,66 @@ def _convert_list_tensor_to_tuple_tensor(list_of_tensor): return list_of_tensor -def _get_mode(): - """Get the current mode (0 is Graph mode, 1 is PyNative mode)""" - return context.get_context('mode') - - def _expand(x, ndim, axis=0): - """Expand x to ndim.""" - while F.rank(x) < ndim: - x = F.expand_dims(x, axis) - return x + """Expand x to ndim from axis, which can be 0 or -1.""" + shape = _add_unit_axes(F.shape(x), ndim, axis == -1) + return F.reshape(x, shape) def _broadcast_to(x, shape_cur, shape_to, ndim_to): """Broadcasts x from shape_cur to shape_to.""" size = _tile_size(shape_cur, shape_to, ndim_to) return F.tile(x, size) + + +def _broadcast_to_shape(x, shape): + """Broadcasts x from current shape to shape""" + ndim_to = len(shape) + x = _expand(x, ndim_to) + return _broadcast_to(x, F.shape(x), shape, ndim_to) + + +def _get_size(x, axis=None): + """Get the number of elements along the given axis of tensor x.""" + if axis is None or F.tuple_len(axis) == 0: + axis = F.make_range(x.ndim) + nums = 1 + for ax in axis: + nums *= x.shape[ax] + return nums + + +def _check_input_tensor(*tensors): + for tensor in tensors: + if not isinstance(tensor, Tensor): + _raise_type_error('expect Tensor, but got ', F.typeof(tensor)) + return True + + +def _convert_64_to_32(tensor): + """Convert tensor with float64/int64 types to float32/int32.""" + if tensor.dtype == mstype.float64: + return tensor.astype("float32") + if tensor.dtype == mstype.int64: + return tensor.astype("int32") + return tensor + + +def _get_dtype_from_scalar(*input_numbers): + """ + Get the final dtype from series of input numbers, compared with F.typeof, we + return int32/float32 for python int/float instead. + """ + bool_flag = True + int_flag = True + for number in input_numbers: + if number is not None: + if not isinstance(number, bool): + bool_flag = False + if not isinstance(number, int): + int_flag = False + if bool_flag: + return mstype.bool_ + if int_flag: + return mstype.int32 + return mstype.float32 diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py index 7c5544c2ae..e8ccd223e5 100644 --- a/mindspore/numpy/utils_const.py +++ b/mindspore/numpy/utils_const.py @@ -13,14 +13,16 @@ # limitations under the License. 
# ============================================================================ """internal graph-compatible utility functions""" +import math from functools import partial import mindspore.context as context from ..ops import functional as F from ..ops.primitive import constexpr from ..common import dtype as mstype +from ..common import Tensor from .._c_expression import Tensor as Tensor_ -from .._c_expression.typing import Tuple, List +from .._c_expression import typing from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map @@ -28,12 +30,17 @@ from .dtypes import promotion_rule, dtype_tuple, all_types, dtype_map @constexpr def _check_shape(shape): """check the shape param to match the numpy style""" - if not isinstance(shape, (int, tuple, list, Tuple, List)): + if not isinstance(shape, (int, tuple, list, typing.Tuple, typing.List)): raise TypeError(f"only int, tuple and list are allowed for shape, but got {type(shape)}") if isinstance(shape, int): shape = (shape,) - if isinstance(shape, (list, List)): + if isinstance(shape, (list, typing.List)): shape = tuple(shape) + for s in shape: + if not isinstance(s, int): + raise TypeError("each entry in shape should be int.") + if s < 0: + raise ValueError("each entry in shape should no less than 0.") return shape @@ -57,7 +64,7 @@ def _check_dtype(dtype): @constexpr -def _check_shape_contain_zero(shp): +def _is_shape_empty(shp): """Check whether shape contains zero""" if isinstance(shp, int): return shp == 0 @@ -77,35 +84,28 @@ def _check_start_normalize(start, ndim): @constexpr def _check_axes_range(axes, ndim): """ - Check axes are within the number of dimensions of tensor x and normalize the negative axes. + Check axes type and normalize the negative axes. + Args: - axes (Union[int, tuple(int), list(int)]): Axes of the tensor. + axes: Axes of the tensor. ndim (int): The number of dimensions of the tensor. + Return: Axes (Union[int, tuple(int)]). If input is integer, return integer, else tuple. + + Raises: + TypeError: If the axes are not integer, tuple(int) or list(int). + ValueError: If duplicate axes exists or some axis is out of bounds. 
""" - if not isinstance(axes, int) and not isinstance(axes, tuple) and not isinstance(axes, list): - raise TypeError(f"int, tuple(int) or list(int) expected, but got {type(axes)}.") - low = -ndim - up = ndim - 1 - if low > up: - raise ValueError(f"Lower bound {low} and upper bound {up} of axes are not allowed.") - if isinstance(axes, int): - if axes < low or axes > up: - raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}.") - return axes if axes >= 0 else axes + ndim - new_axes = [] - for item in axes: - if not isinstance(item, int): - raise TypeError(f"int in tuple or list expected, but got {type(item)}.") - if item < low or item > up: - raise ValueError(f"axis {item} in {axes} is out of bounds for tensor of dimension {ndim}.") - new_axes.append(item if item >= 0 else item + ndim) - return tuple(new_axes) + _check_axis_type(axes, True, True, True) + if isinstance(axes, (list, tuple)): + _check_element_int(axes) + axes = _canonicalize_axis(axes, ndim) + return axes @constexpr -def _get_device_compile(): +def _get_device(): """Get the current device (`GPU`, `CPU`, `Ascend`)""" return context.get_context('device_target') @@ -153,9 +153,10 @@ def _infer_out_shape(*shapes): @constexpr def _check_axis_in_range(axis, ndim): """Checks axes are with the bounds of ndim""" - if -ndim <= axis < ndim: - return True - raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') + if not isinstance(axis, int): + raise TypeError(f'axes should be integers, not {type(axis)}') + if not -ndim <= axis < ndim: + raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') @constexpr @@ -165,26 +166,25 @@ def _check_axis_valid(axes, ndim): to the built-in operator (non-negative, int or tuple) """ if isinstance(axes, int): - _ = _check_axis_in_range(axes, ndim) + _check_axis_in_range(axes, ndim) return (axes % ndim,) - if isinstance(axes, tuple): + if isinstance(axes, (tuple, list)): for axis in axes: - _ = _check_axis_in_range(axis, ndim) + _check_axis_in_range(axis, ndim) axes = tuple(map(lambda x: x % ndim, axes)) if all(axes.count(el) <= 1 for el in axes): return axes if axes is None: axes = F.make_range(ndim) return axes - raise ValueError('duplicate value in \'axis\'') + raise ValueError('duplicate value in "axis"') @constexpr def _check_shape_aligned(shape1, shape2): """Checks shape1 and shape2 are valid shapes to perform inner product""" - if shape1[-1] == shape2[-1]: - return True - raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') + if shape1[-1] != shape2[-1]: + raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') @constexpr @@ -197,30 +197,6 @@ def _tile_size(shape, out_shape, ndim): return tuple(size) -@constexpr -def _check_is_int(obj): - """Check whether obj is an integer.""" - return isinstance(obj, int) - - -@constexpr -def _check_is_tuple(obj): - """Check whether obj is a tuple""" - return isinstance(obj, (tuple, Tuple)) - - -@constexpr -def _check_is_list(obj): - """Check whether obj is a list""" - return isinstance(obj, (list, List)) - - -@constexpr -def _check_is_tensor(obj): - """Check whether obj is a tensor""" - return isinstance(obj, mstype.tensor_type) - - @constexpr def _raise_type_error(info, param=None): """ @@ -298,6 +274,177 @@ def _check_is_float(dtype): @constexpr -def _check_input_tensor(input_type): - if not _check_is_tensor(input_type): - raise TypeError(f'expect Tensor, but got {input_type}') +def 
_check_is_int(dtype): + return isinstance(dtype, typing.Int) + + +@constexpr +def _check_matmul_shapes(shape1, shape2): + """Checks shape1 and shape2 are valid shapes to perform matmul""" + ndim1, ndim2 = len(shape1), len(shape2) + if ndim1 < 1 or ndim2 < 1: + raise ValueError('input operands must have at least 1 dimension') + if ndim2 >= 2 and shape1[-1] != shape2[-2]: + raise ValueError(f'mismatch in core dimension of input operands (size ' + f'{shape1[-1]} is different from {shape2[-2]})') + + +@constexpr +def _check_axis_type(axis, type_int=True, type_tuple=True, type_list=True): + """Check axis argument type.""" + if type_int and isinstance(axis, int): + return True + if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)): + for ax in axis: + if not isinstance(ax, int): + raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axis}.") + return True + + type_str = "" + if type_int: type_str += "int, " + if type_tuple: type_str += "tuple, " + if type_list: type_str += "list, " + raise TypeError(f"Axis should be {type_str}but got {type(axis)}.") + + +@constexpr +def _canonicalize_axis(axis, ndim): + """ + Check axes are within the number of dimensions of tensor x and normalize the negative axes. + Args: + axis (Union[int, tuple(int), list(int)]): Axes of the tensor. + ndim (int): The number of dimensions of the tensor. + Return: + Axis (Union[int, tuple(int)]). If input is integer, return integer, else tuple. + """ + if isinstance(axis, int): + axis = [axis] + for ax in axis: + _check_axis_in_range(ax, ndim) + + def canonicalizer(ax): + return ax + ndim if ax < 0 else ax + + axis = tuple([canonicalizer(axis) for axis in axis]) + if all(axis.count(el) <= 1 for el in axis): + return axis if len(axis) > 1 else axis[0] + raise ValueError(f"duplicate axes in {axis}.") + + +@constexpr +def _broadcast_tuples(tup1, tup2): + """ + Broadcast two 1D tuples to the same length, if inputs are ints, convert to + tuples first. + """ + tup1 = (tup1,) if isinstance(tup1, int) else tup1 + tup2 = (tup2,) if isinstance(tup2, int) else tup2 + if not isinstance(tup1, (tuple, list)) or not isinstance(tup2, (tuple, list)): + raise TypeError("input shift and axis must be tuple or list or int.") + if len(tup1) == len(tup2): + return tup1, tup2 + if len(tup1) == 1: + tup1 *= len(tup2) + elif len(tup2) == 1: + tup2 *= len(tup1) + else: + raise ValueError("shape mismatch: objects cannot be broadcast to a single shape") + return tup1, tup2 + + +@constexpr +def _expanded_shape(ndim, axis_size, axis): + """ + Returns a shape with size = 1 for all dimensions + except at axis. + """ + return tuple([axis_size if i == axis else 1 for i in range(ndim)]) + + +@constexpr +def _add_unit_axes(shape, ndim, append=False): + """ + Prepends shape with 1s so that it has the number of dimensions ndim. + If append is set to True, returns shape appended with 1s instead. + """ + if isinstance(shape, int): + shape = (shape,) + ndim_diff = ndim - len(shape) + if ndim_diff > 0: + if append: + shape = [i for i in shape] + [1]*ndim_diff + else: + shape = [1]*ndim_diff + [i for i in shape] + return tuple(shape) + + +@constexpr +def _check_element_int(lst): + """ + Check whether each element in `lst` is an integer. + """ + for item in lst: + if not isinstance(item, int): + raise TypeError(f"Each element in {lst} should be integer, but got {type(item)}.") + return True + + +@constexpr +def _type_convert(force, obj): + """ + Convert type of `obj` to `force`. 
+ """ + return force(obj) + + +@constexpr +def _list_comprehensions(obj, item=None, return_tuple=False): + """ + Generates a new list/tuple by list comprehension. + + Args: + obj (Union[int, list, tuple]): + If integer, it will be the length of the returned tuple/list. + item: The value to be filled. Default: None. + If None, the values in the new list/tuple are the same as obj + or range(obj) when obj is integer. + return_tuple(bool): If true, returns tuple, else returns list. + + Returns: + List or tuple. + """ + res = [] + lst = obj + if isinstance(obj, int): + lst = range(obj) + if item is None: + res = [i for i in lst] + else: + res = [item for i in lst] + if return_tuple: + return tuple(res) + return res + + +@constexpr +def _tuple_getitem(tup, idx, startswith=True): + """ + Returns a slice from tup starting with idx. If startswith is False, + returns a lice from tup ending with idx instead. + """ + if startswith: + return tup[idx:] + return tup[:idx] + + +@constexpr +def _iota(dtype, num): + """Creates a 1-D tensor with value: [0,1,...num-1] and dtype.""" + # TODO: Change to P.Linspace when the kernel is implemented on CPU. + return Tensor(list(range(int(num))), dtype) + + +@constexpr +def _ceil(number): + """Ceils the number in graph mode.""" + return math.ceil(number) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 670af6fab0..bf283d6eda 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -59,18 +59,25 @@ tensor_div = P.RealDiv() tensor_floordiv = P.FloorDiv() tensor_pow = P.Pow() tensor_mod = P.FloorMod() +tensor_exp = P.Exp() +tensor_expm1 = P.Expm1() strided_slice = P.StridedSlice() same_type_shape = P.SameTypeShape() check_bprop = P.CheckBprop() equal = P.Equal() not_equal = P.NotEqual() +isfinite = P.IsFinite() assign_sub = P.AssignSub() assign_add = P.AssignAdd() assign = P.Assign() square = P.Square() sqrt = P.Sqrt() +log = P.Log() reduce_sum = P.ReduceSum() tensor_slice = P.Slice() +maximum = P.Maximum() +minimum = P.Minimum() +floor = P.Floor() scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() @@ -82,6 +89,7 @@ transpose = P.Transpose() squeeze = P.Squeeze() scatter_nd = P.ScatterNd() gather = P.Gather() +gather_d = P.GatherD() gather_nd = P.GatherNd() scatter_update = P.ScatterUpdate() scatter_nd_update = P.ScatterNdUpdate() diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index b60e4027be..c10a5d3847 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -14,15 +14,13 @@ # ============================================================================ """unit tests for numpy array operations""" -import functools - import pytest import numpy as onp -import mindspore.context as context import mindspore.numpy as mnp -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') +from .utils import rand_int, rand_bool, match_array, match_res, match_meta, \ + match_all_arrays class Cases(): @@ -97,10 +95,10 @@ class Cases(): self.mnp_prototypes = [ mnp.ones((2, 3, 4)), mnp.ones((0, 3, 0, 2, 5)), - onp.ones((2, 7, 0)), - onp.ones(()), - [mnp.ones(3), (1, 2, 3), onp.ones(3), [4, 5, 6]], - ([(1, 2), mnp.ones(2)], (onp.ones(2), [3, 4])), + mnp.ones((2, 7, 0)), + mnp.ones(()), + [mnp.ones(3), (1, 2, 3), mnp.ones(3), [4, 5, 6]], + ([(1, 2), mnp.ones(2)], (mnp.ones(2), [3, 4])), ] self.onp_prototypes = [ @@ -113,97 +111,6 @@ class Cases(): ] -def match_array(actual, expected, error=0): - 
if error > 0: - onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), - decimal=error) - else: - onp.testing.assert_equal(actual.tolist(), expected.tolist()) - - -def check_all_results(onp_results, mnp_results, error=0): - """Check all results from numpy and mindspore.numpy""" - for i, _ in enumerate(onp_results): - match_array(onp_results[i], mnp_results[i].asnumpy()) - - -def check_all_unique_results(onp_results, mnp_results): - """ - Check all results from numpy and mindspore.numpy. - - Args: - onp_results (Union[tuple of numpy.arrays, numpy.array]) - mnp_results (Union[tuple of Tensors, Tensor]) - """ - for i, _ in enumerate(onp_results): - if isinstance(onp_results[i], tuple): - for j in range(len(onp_results[i])): - match_array(onp_results[i][j], - mnp_results[i][j].asnumpy(), error=7) - else: - match_array(onp_results[i], mnp_results[i].asnumpy(), error=7) - - -def run_non_kw_test(mnp_fn, onp_fn): - """Run tests on functions with non keyword arguments""" - test_case = Cases() - for i in range(len(test_case.arrs)): - arrs = test_case.arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.scalars)): - arrs = test_case.scalars[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.expanded_arrs)): - arrs = test_case.expanded_arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.nested_arrs)): - arrs = test_case.nested_arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - -def rand_int(*shape): - """return an random integer array with parameter shape""" - res = onp.random.randint(low=1, high=5, size=shape) - if isinstance(res, onp.ndarray): - return res.astype(onp.float32) - return float(res) - - -# return an random boolean array -def rand_bool(*shape): - return onp.random.rand(*shape) > 0.5 - - -def match_res(mnp_fn, onp_fn, *arrs, **kwargs): - """Checks results from applying mnp_fn and onp_fn on arrs respectively""" - mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) - mnp_res = mnp_fn(*mnp_arrs, **kwargs) - onp_res = onp_fn(*arrs, **kwargs) - match_all_arrays(mnp_res, onp_res) - - -def match_all_arrays(mnp_res, onp_res, error=0): - if isinstance(mnp_res, (tuple, list)): - for actual, expected in zip(mnp_res, onp_res): - match_array(actual.asnumpy(), expected, error) - else: - match_array(mnp_res.asnumpy(), onp_res, error) - - -def match_meta(actual, expected): - # float64 and int64 are not supported, and the default type for - # float and int are float32 and int32, respectively - if expected.dtype == onp.float64: - expected = expected.astype(onp.float32) - elif expected.dtype == onp.int64: - expected = expected.astype(onp.int32) - assert actual.shape == expected.shape - assert actual.dtype == expected.dtype - - @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -440,27 +347,50 @@ def test_arange(): def test_linspace(): actual = onp.linspace(2.0, 3.0, dtype=onp.float32) expected = mnp.linspace(2.0, 3.0).asnumpy() - match_array(actual, expected, error=7) + match_array(actual, expected, error=6) actual = onp.linspace(2.0, 3.0, num=5, dtype=onp.float32) expected = mnp.linspace(2.0, 3.0, num=5).asnumpy() - match_array(actual, expected, error=7) + match_array(actual, expected, error=6) actual = onp.linspace( 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) expected = mnp.linspace(2.0, 3.0, num=5, endpoint=False).asnumpy() - match_array(actual, expected, error=7) + match_array(actual, expected, error=6) actual = onp.linspace(2.0, 3.0, num=5, 
retstep=True, dtype=onp.float32) expected = mnp.linspace(2.0, 3.0, num=5, retstep=True) match_array(actual[0], expected[0].asnumpy()) - assert actual[1] == expected[1] + assert actual[1] == expected[1].asnumpy() actual = onp.linspace(2.0, [3, 4, 5], num=5, endpoint=False, dtype=onp.float32) expected = mnp.linspace( 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() - match_array(actual, expected) + match_array(actual, expected, error=6) + + start = onp.random.random([2, 1, 4]) + stop = onp.random.random([1, 5, 1]) + actual = onp.linspace(start, stop, num=20, retstep=True, + endpoint=False, dtype=onp.float32) + expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + retstep=True, endpoint=False) + match_array(actual[0], expected[0].asnumpy(), error=6) + match_array(actual[1], expected[1].asnumpy(), error=6) + + actual = onp.linspace(start, stop, num=20, retstep=True, + endpoint=False, dtype=onp.int16) + expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + retstep=True, endpoint=False, dtype=mnp.int16) + match_array(actual[0], expected[0].asnumpy(), error=6) + match_array(actual[1], expected[1].asnumpy(), error=6) + + for axis in range(2): + actual = onp.linspace(start, stop, num=20, retstep=False, + endpoint=False, dtype=onp.float32, axis=axis) + expected = mnp.linspace(mnp.asarray(start), mnp.asarray(stop), num=20, + retstep=False, endpoint=False, dtype=mnp.float32, axis=axis) + match_array(actual, expected.asnumpy(), error=6) @pytest.mark.level1 @@ -472,22 +402,22 @@ def test_linspace(): def test_logspace(): actual = onp.logspace(2.0, 3.0, dtype=onp.float32) expected = mnp.logspace(2.0, 3.0).asnumpy() - match_array(actual, expected) + match_array(actual, expected, error=3) actual = onp.logspace(2.0, 3.0, num=5, dtype=onp.float32) expected = mnp.logspace(2.0, 3.0, num=5).asnumpy() - match_array(actual, expected) + match_array(actual, expected, error=3) actual = onp.logspace( 2.0, 3.0, num=5, endpoint=False, dtype=onp.float32) expected = mnp.logspace(2.0, 3.0, num=5, endpoint=False).asnumpy() - match_array(actual, expected) + match_array(actual, expected, error=3) - actual = onp.logspace(2.0, [3, 4, 5], num=5, + actual = onp.logspace(2.0, [3, 4, 5], num=5, base=2, endpoint=False, dtype=onp.float32) expected = mnp.logspace( - 2.0, [3, 4, 5], num=5, endpoint=False).asnumpy() - match_array(actual, expected) + 2.0, [3, 4, 5], num=5, base=2, endpoint=False).asnumpy() + match_array(actual, expected, error=3) @pytest.mark.level1 @@ -537,7 +467,6 @@ def run_x_like(mnp_fn, onp_fn): actual = mnp_fn(mnp_proto, shape=shape).asnumpy() expected = onp_fn(onp_proto, shape=shape) match_array(actual, expected) - for mnp_dtype, onp_dtype in zip(test_case.mnp_dtypes, test_case.onp_dtypes): actual = mnp_fn(mnp_proto, dtype=mnp_dtype).asnumpy() @@ -581,18 +510,18 @@ def test_full_like(): for mnp_proto, onp_proto in zip(test_case.mnp_prototypes, test_case.onp_prototypes): shape = onp.zeros_like(onp_proto).shape fill_value = rand_int() - actual = mnp.full_like(mnp_proto, fill_value).asnumpy() + actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) for i in range(len(shape) - 1, 0, -1): fill_value = rand_int(*shape[i:]) - actual = mnp.full_like(mnp_proto, fill_value).asnumpy() + actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) fill_value = rand_int(1, *shape[i + 1:]) - actual = 
mnp.full_like(mnp_proto, fill_value).asnumpy() + actual = mnp.full_like(mnp_proto, mnp.array(fill_value)).asnumpy() expected = onp.full_like(onp_proto, fill_value) match_array(actual, expected) @@ -620,6 +549,26 @@ def test_tri_triu_tril(): match_array(mnp.tri(64, 64, -10).asnumpy(), onp.tri(64, 64, -10)) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cumsum(): + x = mnp.ones((16, 16), dtype="bool") + match_array(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) + match_array(mnp.cumsum(x, axis=0).asnumpy(), + onp.cumsum(x.asnumpy(), axis=0)) + match_meta(mnp.cumsum(x).asnumpy(), onp.cumsum(x.asnumpy())) + + x = rand_int(3, 4, 5) + match_array(mnp.cumsum(mnp.asarray(x), dtype="bool").asnumpy(), + onp.cumsum(x, dtype="bool")) + match_array(mnp.cumsum(mnp.asarray(x), axis=-1).asnumpy(), + onp.cumsum(x, axis=-1)) + + def mnp_diagonal(arr): return mnp.diagonal(arr, offset=2, axis1=-1, axis2=0) @@ -697,6 +646,138 @@ def test_trace(): match_res(mnp.trace, onp.trace, arr, offset=i, axis1=2, axis2=-1) +def mnp_meshgrid(*xi): + a = mnp.meshgrid(*xi) + b = mnp.meshgrid(*xi, sparse=True) + c = mnp.meshgrid(*xi, indexing='ij') + d = mnp.meshgrid(*xi, sparse=False, indexing='ij') + return a, b, c, d + + +def onp_meshgrid(*xi): + a = onp.meshgrid(*xi) + b = onp.meshgrid(*xi, sparse=True) + c = onp.meshgrid(*xi, indexing='ij') + d = onp.meshgrid(*xi, sparse=False, indexing='ij') + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_meshgrid(): + xi = (onp.full(3, 2), onp.full(1, 5), onp.full( + (2, 3), 9), onp.full((4, 5, 6), 7)) + for i in range(len(xi)): + arrs = xi[i:] + mnp_arrs = map(mnp.asarray, arrs) + for mnp_res, onp_res in zip(mnp_meshgrid(*mnp_arrs), onp_meshgrid(*arrs)): + match_all_arrays(mnp_res, onp_res) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_mgrid(): + mnp_res = mnp.mgrid[0:5] + onp_res = onp.mgrid[0:5] + match_all_arrays(mnp_res, onp_res, error=5) + + mnp_res = mnp.mgrid[2:30:4j, -10:20:7, 2:5:0.5] + onp_res = onp.mgrid[2:30:4j, -10:20:7, 2:5:0.5] + match_all_arrays(mnp_res, onp_res, error=5) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ogrid(): + mnp_res = mnp.ogrid[0:5] + onp_res = onp.ogrid[0:5] + match_all_arrays(mnp_res, onp_res, error=5) + + mnp_res = mnp.ogrid[2:30:4j, -10:20:7, 2:5:0.5] + onp_res = onp.ogrid[2:30:4j, -10:20:7, 2:5:0.5] + match_all_arrays(mnp_res, onp_res, error=5) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_diagflat(): + arrs = [rand_int(0), rand_int(2, 3), rand_int(3, 5, 0)] + for arr in arrs: + for i in [-2, 0, 7]: + match_res(mnp.diagflat, onp.diagflat, arr, k=i) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training 
+@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_diag(): + arrs = [rand_int(0), rand_int(0, 0), rand_int(7), rand_int(5, 5), + rand_int(3, 8), rand_int(9, 6)] + for arr in arrs: + for i in [-10, -5, -1, 0, 2, 5, 6, 10]: + match_res(mnp.diag, onp.diag, arr, k=i) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_diag_indices(): + mnp_res = mnp.diag_indices(0) + onp_res = onp.diag_indices(0) + match_all_arrays(mnp_res, onp_res) + + mnp_res = mnp.diag_indices(3, 0) + onp_res = onp.diag_indices(3, 0) + match_all_arrays(mnp_res, onp_res) + + mnp_res = mnp.diag_indices(5, 7) + onp_res = onp.diag_indices(5, 7) + match_all_arrays(mnp_res, onp_res) + + +def mnp_ix_(*args): + return mnp.ix_(*args) + + +def onp_ix_(*args): + return onp.ix_(*args) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ix_(): + arrs = [rand_int(i + 1) for i in range(10)] + for i in range(10): + test_arrs = arrs[:i + 1] + match_res(mnp_ix_, onp_ix_, *test_arrs) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training diff --git a/tests/st/numpy_native/test_array_ops.py b/tests/st/numpy_native/test_array_ops.py index c27bcdc855..d6dbd86fb7 100644 --- a/tests/st/numpy_native/test_array_ops.py +++ b/tests/st/numpy_native/test_array_ops.py @@ -22,6 +22,9 @@ import numpy as onp import mindspore.numpy as mnp from mindspore.nn import Cell +from .utils import rand_int, run_non_kw_test, check_all_results, match_array, \ + rand_bool, match_res, run_multi_test + class Cases(): def __init__(self): @@ -111,81 +114,6 @@ class Cases(): ] -def match_array(actual, expected, error=0): - if error > 0: - onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), - decimal=error) - else: - onp.testing.assert_equal(actual.tolist(), expected.tolist()) - - -def check_all_results(onp_results, mnp_results, error=0): - """Check all results from numpy and mindspore.numpy""" - for i, _ in enumerate(onp_results): - match_array(onp_results[i], mnp_results[i].asnumpy()) - - -def run_non_kw_test(mnp_fn, onp_fn): - """Run tests on functions with non keyword arguments""" - test_case = Cases() - for i in range(len(test_case.arrs)): - arrs = test_case.arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.scalars)): - arrs = test_case.scalars[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.expanded_arrs)): - arrs = test_case.expanded_arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - for i in range(len(test_case.nested_arrs)): - arrs = test_case.nested_arrs[:i] - match_res(mnp_fn, onp_fn, *arrs) - - -def rand_int(*shape): - """return an random integer array with parameter shape""" - res = onp.random.randint(low=1, high=5, size=shape) - if isinstance(res, onp.ndarray): - return res.astype(onp.float32) - return float(res) - - -# return an random boolean array -def rand_bool(*shape): - return onp.random.rand(*shape) > 0.5 - - -def match_res(mnp_fn, onp_fn, *arrs, **kwargs): - """Checks results from applying mnp_fn and onp_fn on arrs respectively""" - mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs) - mnp_res = mnp_fn(*mnp_arrs, **kwargs) - onp_res = 
onp_fn(*arrs, **kwargs) - match_all_arrays(mnp_res, onp_res) - - -def match_all_arrays(mnp_res, onp_res, error=0): - if isinstance(mnp_res, (tuple, list)): - assert len(mnp_res) == len(onp_res) - for actual, expected in zip(mnp_res, onp_res): - match_array(actual.asnumpy(), expected, error) - else: - match_array(mnp_res.asnumpy(), onp_res, error) - - -def match_meta(actual, expected): - # float64 and int64 are not supported, and the default type for - # float and int are float32 and int32, respectively - if expected.dtype == onp.float64: - expected = expected.astype(onp.float32) - elif expected.dtype == onp.int64: - expected = expected.astype(onp.int32) - assert actual.shape == expected.shape - assert actual.dtype == expected.dtype - - # Test np.transpose and np.ndarray.transpose def mnp_transpose(input_tensor): a = mnp.transpose(input_tensor, (0, 2, 1)) @@ -458,6 +386,34 @@ def test_concatenate(): check_all_results(o_concatenate, m_concatenate) +def mnp_append(arr1, arr2): + a = mnp.append(arr1, arr2) + b = mnp.append(arr1, arr2, axis=0) + c = mnp.append(arr1, arr2, axis=-1) + return a, b, c + +def onp_append(arr1, arr2): + a = onp.append(arr1, arr2) + b = onp.append(arr1, arr2, axis=0) + c = onp.append(arr1, arr2, axis=-1) + return a, b, c + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_append(): + onp_array = onp.random.random((4, 3, 2)).astype('float32') + onp_value = onp.random.random((4, 3, 2)).astype('float32') + mnp_array = mnp.asarray(onp_array) + mnp_value = mnp.asarray(onp_value) + onp_res = onp_append(onp_array, onp_value) + mnp_res = mnp_append(mnp_array, mnp_value) + check_all_results(onp_res, mnp_res) + + def construct_arrays(n=1, ndim=1, axis=None, low=1, high=5): onp_array_lst = [] mnp_array_lst = [] @@ -629,7 +585,7 @@ def onp_atleast3d(*arys): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_atleast1d(): - run_non_kw_test(mnp_atleast1d, onp_atleast1d) + run_non_kw_test(mnp_atleast1d, onp_atleast1d, Cases()) @pytest.mark.level1 @@ -639,7 +595,7 @@ def test_atleast1d(): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_atleast2d(): - run_non_kw_test(mnp_atleast2d, onp_atleast2d) + run_non_kw_test(mnp_atleast2d, onp_atleast2d, Cases()) @pytest.mark.level1 @@ -649,7 +605,7 @@ def test_atleast2d(): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_atleast3d(): - run_non_kw_test(mnp_atleast3d, onp_atleast3d) + run_non_kw_test(mnp_atleast3d, onp_atleast3d, Cases()) # Test np.where @@ -858,6 +814,444 @@ def test_stack(): match_res(mnp.stack, onp.stack, arrs, axis=i) +def mnp_roll(input_tensor): + a = mnp.roll(input_tensor, -3) + b = mnp.roll(input_tensor, [-2, -3], 1) + c = mnp.roll(input_tensor, (3, 0, -5), (-1, -2, 0)) + d = mnp.roll(input_tensor, (4,), [0, 0, 1]) + return a, b, c, d + + +def onp_roll(input_array): + a = onp.roll(input_array, -3) + b = onp.roll(input_array, [-2, -3], 1) + c = onp.roll(input_array, (3, 0, -5), (-1, -2, 0)) + d = onp.roll(input_array, (4,), [0, 0, 1]) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_roll(): + arr = rand_int(3, 4, 5) + match_res(mnp_roll, onp_roll, arr) + arr = rand_int(1, 4, 6).astype("int64") + match_res(mnp_roll, onp_roll, arr) + + 
+def mnp_moveaxis(a): + a = mnp.moveaxis(a, 3, 3) + b = mnp.moveaxis(a, -1, 4) + c = mnp.moveaxis(a, (2, 1, 4), (0, 3, 2)) + d = mnp.moveaxis(a, [-2, -5], [2, -4]) + return a, b, c, d + + +def onp_moveaxis(a): + a = onp.moveaxis(a, 3, 3) + b = onp.moveaxis(a, -1, 4) + c = onp.moveaxis(a, (2, 1, 4), (0, 3, 2)) + d = onp.moveaxis(a, [-2, -5], [2, -4]) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_moveaxis(): + a = rand_int(2, 4, 5, 9, 6) + match_res(mnp_moveaxis, onp_moveaxis, a) + a = rand_int(2, 4, 5, 0, 6, 7, 1, 3, 8) + match_res(mnp_moveaxis, onp_moveaxis, a) + + +def mnp_tile(x): + a = mnp.tile(x, 0) + b = mnp.tile(x, 1) + c = mnp.tile(x, 3) + d = mnp.tile(x, [5, 1]) + e = mnp.tile(x, (3, 1, 0)) + f = mnp.tile(x, [5, 1, 2, 3, 7]) + return a, b, c, d, e, f + + +def onp_tile(x): + a = onp.tile(x, 0) + b = onp.tile(x, 1) + c = onp.tile(x, 3) + d = onp.tile(x, [5, 1]) + e = onp.tile(x, (3, 1, 0)) + f = onp.tile(x, [5, 1, 2, 3, 7]) + return a, b, c, d, e, f + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_tile(): + a = rand_int(2, 3, 4) + match_res(mnp_tile, onp_tile, a) + b = rand_int(5, 0, 8) + match_res(mnp_tile, onp_tile, b) + + +def mnp_broadcast_to(x): + a = mnp.broadcast_to(x, (2, 3)) + b = mnp.broadcast_to(x, (8, 1, 3)) + return a, b + + +def onp_broadcast_to(x): + a = onp.broadcast_to(x, (2, 3)) + b = onp.broadcast_to(x, (8, 1, 3)) + return a, b + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_broadcast_to(): + x = rand_int() + match_res(mnp_broadcast_to, onp_broadcast_to, x) + x = rand_int(3) + match_res(mnp_broadcast_to, onp_broadcast_to, x) + x = rand_int(1, 3) + match_res(mnp_broadcast_to, onp_broadcast_to, x) + + +def mnp_broadcast_arrays(*args): + return mnp.broadcast_arrays(*args) + + +def onp_broadcast_arrays(*args): + return onp.broadcast_arrays(*args) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_broadcast_arrays(): + test_case = Cases() + broadcastables = test_case.broadcastables + for i in range(len(broadcastables)): + arrs = broadcastables[i:] + match_res(mnp_broadcast_arrays, onp_broadcast_arrays, *arrs) + + +def mnp_flip(x): + a = mnp.flip(x) + b = mnp.flip(x, 0) + c = mnp.flip(x, 1) + d = mnp.flip(x, (-3, -1)) + return a, b, c, d + + +def onp_flip(x): + a = onp.flip(x) + b = onp.flip(x, 0) + c = onp.flip(x, 1) + d = onp.flip(x, (-3, -1)) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_flip(): + x = rand_int(2, 3, 4) + run_multi_test(mnp_flip, onp_flip, (x,)) + + +def mnp_flipud(x): + return mnp.flipud(x) + + +def onp_flipud(x): + return onp.flipud(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training 
+@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_flipud(): + x = rand_int(2, 3, 4) + run_multi_test(mnp_flipud, onp_flipud, (x,)) + + +def mnp_fliplr(x): + return mnp.fliplr(x) + + +def onp_fliplr(x): + return onp.fliplr(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_fliplr(): + x = rand_int(2, 3, 4) + run_multi_test(mnp_fliplr, onp_fliplr, (x,)) + + +def mnp_split(input_tensor): + a = mnp.split(input_tensor, indices_or_sections=1) + b = mnp.split(input_tensor, indices_or_sections=3) + c = mnp.split(input_tensor, indices_or_sections=(-9, -8, 6)) + d = mnp.split(input_tensor, indices_or_sections=(3, 2, 1)) + e = mnp.split(input_tensor, indices_or_sections=(-10, -4, 5, 10)) + f = mnp.split(input_tensor, indices_or_sections=[0, 2], axis=1) + return a, b, c, d, e, f + + +def onp_split(input_array): + a = onp.split(input_array, indices_or_sections=1) + b = onp.split(input_array, indices_or_sections=3) + c = onp.split(input_array, indices_or_sections=(-9, -8, 6)) + d = onp.split(input_array, indices_or_sections=(3, 2, 1)) + e = onp.split(input_array, indices_or_sections=(-10, -4, 5, 10)) + f = onp.split(input_array, indices_or_sections=[0, 2], axis=1) + return a, b, c, d, e, f + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_split(): + onp_arrs = [ + onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'), + onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64') + ] + mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): + o_split = onp_split(onp_arr) + m_split = mnp_split(mnp_arr) + for expect_lst, actual_lst in zip(o_split, m_split): + for expect, actual in zip(expect_lst, actual_lst): + match_array(expect, actual.asnumpy()) + + +def mnp_vsplit(input_tensor): + a = mnp.vsplit(input_tensor, indices_or_sections=3) + b = mnp.vsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) + c = mnp.vsplit(input_tensor, indices_or_sections=[0, 2]) + return a, b, c + + +def onp_vsplit(input_array): + a = onp.vsplit(input_array, indices_or_sections=3) + b = onp.vsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) + c = onp.vsplit(input_array, indices_or_sections=[0, 2]) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_vsplit(): + onp_arrs = [ + onp.random.randint(1, 5, size=(9, 4, 5)).astype('float32'), + onp.random.randint(1, 5, size=(9, 4, 5)).astype('float64') + ] + mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): + o_vsplit = onp_vsplit(onp_arr) + m_vsplit = mnp_vsplit(mnp_arr) + for expect_lst, actual_lst in zip(o_vsplit, m_vsplit): + for expect, actual in zip(expect_lst, actual_lst): + match_array(expect, actual.asnumpy()) + + +def mnp_hsplit(input_tensor): + a = mnp.hsplit(input_tensor, indices_or_sections=3) + b = mnp.hsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) + c = mnp.hsplit(input_tensor, indices_or_sections=[0, 2]) + return a, b, c + + +def onp_hsplit(input_array): + a = onp.hsplit(input_array, 
indices_or_sections=3) + b = onp.hsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) + c = onp.hsplit(input_array, indices_or_sections=[0, 2]) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_hsplit(): + onp_arrs = [ + onp.random.randint(1, 5, size=(4, 9, 5)).astype('float32'), + onp.random.randint(1, 5, size=(4, 9, 5)).astype('float64') + ] + mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): + o_hsplit = onp_hsplit(onp_arr) + m_hsplit = mnp_hsplit(mnp_arr) + for expect_lst, actual_lst in zip(o_hsplit, m_hsplit): + for expect, actual in zip(expect_lst, actual_lst): + match_array(expect, actual.asnumpy()) + + +def mnp_dsplit(input_tensor): + a = mnp.dsplit(input_tensor, indices_or_sections=3) + b = mnp.dsplit(input_tensor, indices_or_sections=(-10, -4, 5, 10)) + c = mnp.dsplit(input_tensor, indices_or_sections=[0, 2]) + return a, b, c + + +def onp_dsplit(input_array): + a = onp.dsplit(input_array, indices_or_sections=3) + b = onp.dsplit(input_array, indices_or_sections=(-10, -4, 5, 10)) + c = onp.dsplit(input_array, indices_or_sections=[0, 2]) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_dsplit(): + onp_arrs = [ + onp.random.randint(1, 5, size=(5, 4, 9)).astype('float32'), + onp.random.randint(1, 5, size=(5, 4, 9)).astype('float64') + ] + mnp_arrs = [mnp.asarray(arr) for arr in onp_arrs] + for onp_arr, mnp_arr in zip(onp_arrs, mnp_arrs): + o_dsplit = onp_dsplit(onp_arr) + m_dsplit = mnp_dsplit(mnp_arr) + for expect_lst, actual_lst in zip(o_dsplit, m_dsplit): + for expect, actual in zip(expect_lst, actual_lst): + match_array(expect, actual.asnumpy()) + + +def mnp_take_along_axis(*arrs): + x = arrs[0] + a = mnp.take_along_axis(x, arrs[1], axis=None) + b = mnp.take_along_axis(x, arrs[2], axis=1) + c = mnp.take_along_axis(x, arrs[3], axis=-1) + d = mnp.take_along_axis(x, arrs[4], axis=0) + return a, b, c, d + + +def onp_take_along_axis(*arrs): + x = arrs[0] + a = onp.take_along_axis(x, arrs[1], axis=None) + b = onp.take_along_axis(x, arrs[2], axis=1) + c = onp.take_along_axis(x, arrs[3], axis=-1) + d = onp.take_along_axis(x, arrs[4], axis=0) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_take_along_axis(): + x = rand_int(6, 7, 8, 9) + indices1 = rand_int(2).astype(onp.int32) + indices2 = rand_int(6, 3, 8, 1).astype(onp.int32) + indices3 = rand_int(6, 1, 8, 5).astype(onp.int32) + indices4 = rand_int(4, 1, 1, 1).astype(onp.int32) + run_multi_test(mnp_take_along_axis, onp_take_along_axis, + (x, indices1, indices2, indices3, indices4)) + + +def mnp_take(x, indices): + a = mnp.take(x, indices) + b = mnp.take(x, indices, axis=-1) + c = mnp.take(x, indices, axis=0, mode='wrap') + d = mnp.take(x, indices, axis=1, mode='clip') + return a, b, c, d + + +def onp_take(x, indices): + a = onp.take(x, indices) + b = onp.take(x, indices, axis=-1) + c = onp.take(x, indices, axis=0, mode='wrap') + d = onp.take(x, indices, axis=1, mode='clip') + return a, b, c, d + + +@pytest.mark.level1 
+@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_take(): + x = rand_int(2, 3, 4, 5) + indices = rand_int(2, 3).astype(onp.int32) + run_multi_test(mnp_take, onp_take, (x, indices)) + + +def mnp_repeat(x): + a = mnp.repeat(x, 2) + b = mnp.repeat(x, 3, axis=0) + c = mnp.repeat(x, (4, 1, 5), axis=1) + d = mnp.repeat(x, (3, 2, 1, 0, 4), axis=-1) + e = mnp.repeat(x, 0) + return a, b, c, d, e + + +def onp_repeat(x): + a = onp.repeat(x, 2) + b = onp.repeat(x, 3, axis=0) + c = onp.repeat(x, (4, 1, 5), axis=1) + d = onp.repeat(x, (3, 2, 1, 0, 4), axis=-1) + e = onp.repeat(x, 0) + return a, b, c, d, e + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_repeat(): + x = rand_int(2, 3, 4, 5) + run_multi_test(mnp_repeat, onp_repeat, (x,)) + + class ReshapeExpandSqueeze(Cell): def __init__(self): super(ReshapeExpandSqueeze, self).__init__() diff --git a/tests/st/numpy_native/test_logic_ops.py b/tests/st/numpy_native/test_logic_ops.py new file mode 100644 index 0000000000..14fa603aa0 --- /dev/null +++ b/tests/st/numpy_native/test_logic_ops.py @@ -0,0 +1,263 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""unit tests for numpy logical operations""" + +import pytest +import numpy as onp + +import mindspore.numpy as mnp + +from .utils import rand_int, run_binop_test, match_res + + +class Cases(): + def __init__(self): + self.arrs = [ + rand_int(2), + rand_int(2, 3), + rand_int(2, 3, 4), + rand_int(2, 3, 4, 5), + ] + + # scalars expanded across the 0th dimension + self.scalars = [ + rand_int(), + rand_int(1), + rand_int(1, 1), + rand_int(1, 1, 1, 1), + ] + + # arrays of the same size expanded across the 0th dimension + self.expanded_arrs = [ + rand_int(2, 3), + rand_int(1, 2, 3), + rand_int(1, 1, 2, 3), + rand_int(1, 1, 1, 2, 3), + ] + + # arrays which can be broadcast + self.broadcastables = [ + rand_int(5), + rand_int(6, 1), + rand_int(7, 1, 5), + rand_int(8, 1, 6, 1) + ] + + # array which contains infs and nans + self.infs = onp.array([[1.0, onp.nan], [onp.inf, onp.NINF], [2.3, -4.5], [onp.nan, 0.0]]) + + +test_case = Cases() + + +def mnp_not_equal(a, b): + return mnp.not_equal(a, b) + + +def onp_not_equal(a, b): + return onp.not_equal(a, b) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_not_equal(): + run_binop_test(mnp_not_equal, onp_not_equal, test_case) + + +def mnp_less_equal(a, b): + return mnp.less_equal(a, b) + + +def onp_less_equal(a, b): + return onp.less_equal(a, b) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_less_equal(): + run_binop_test(mnp_less_equal, onp_less_equal, test_case) + + +def mnp_less(a, b): + return mnp.less(a, b) + + +def onp_less(a, b): + return onp.less(a, b) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_less(): + run_binop_test(mnp_less, onp_less, test_case) + + +def mnp_greater_equal(a, b): + return mnp.greater_equal(a, b) + + +def onp_greater_equal(a, b): + return onp.greater_equal(a, b) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_greater_equal(): + run_binop_test(mnp_greater_equal, onp_greater_equal, test_case) + + +def mnp_greater(a, b): + return mnp.greater(a, b) + + +def onp_greater(a, b): + return onp.greater(a, b) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_greater(): + run_binop_test(mnp_greater, onp_greater, test_case) + + +def mnp_equal(a, b): + return mnp.equal(a, b) + + +def onp_equal(a, b): + return onp.equal(a, b) + + +def test_equal(): + run_binop_test(mnp_equal, onp_equal, test_case) + + +def mnp_isfinite(x): + return mnp.isfinite(x) + + +def onp_isfinite(x): + return onp.isfinite(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isfinite(): + 
match_res(mnp_isfinite, onp_isfinite, test_case.infs) + + +def mnp_isnan(x): + return mnp.isnan(x) + + +def onp_isnan(x): + return onp.isnan(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isnan(): + match_res(mnp_isnan, onp_isnan, test_case.infs) + + +def mnp_isinf(x): + return mnp.isinf(x) + + +def onp_isinf(x): + return onp.isinf(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isinf(): + match_res(mnp_isinf, onp_isinf, test_case.infs) + + +def mnp_isposinf(x): + return mnp.isposinf(x) + + +def onp_isposinf(x): + return onp.isposinf(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isposinf(): + match_res(mnp_isposinf, onp_isposinf, test_case.infs) + + +def mnp_isneginf(x): + return mnp.isneginf(x) + + +def onp_isneginf(x): + return onp.isneginf(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_isneginf(): + match_res(mnp_isneginf, onp_isneginf, test_case.infs) + + +def test_isscalar(): + assert mnp.isscalar(1) == onp.isscalar(1) + assert mnp.isscalar(2.3) == onp.isscalar(2.3) + assert mnp.isscalar([4.5]) == onp.isscalar([4.5]) + assert mnp.isscalar(False) == onp.isscalar(False) + assert mnp.isscalar(mnp.array(True)) == onp.isscalar(onp.array(True)) + assert mnp.isscalar('numpy') == onp.isscalar('numpy') diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index e72ac0fcbe..552f0a0d7e 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -13,32 +13,17 @@ # limitations under the License. 
# ============================================================================ """unit tests for numpy math operations""" -from functools import partial import pytest import numpy as onp import mindspore.numpy as mnp -from mindspore import context - - -def rand_int(*shape): - """return an random integer array with parameter shape""" - res = onp.random.randint(low=1, high=5, size=shape) - if isinstance(res, onp.ndarray): - return res.astype(onp.float32) - return float(res) - - -# return an random boolean array -def rand_bool(*shape): - return onp.random.rand(*shape) > 0.5 +from .utils import rand_int, rand_bool, run_binop_test, run_unary_test, run_multi_test, \ + run_single_test, match_res, match_array class Cases(): def __init__(self): - self.device_cpu = context.get_context('device_target') - self.arrs = [ rand_int(2), rand_int(2, 3), @@ -70,14 +55,6 @@ class Cases(): rand_int(1, 1, 1, 2, 3), ] - # arrays of the same size expanded across the 0th dimension - self.expanded_arrs = [ - rand_int(2, 3), - rand_int(1, 2, 3), - rand_int(1, 1, 2, 3), - rand_int(1, 1, 1, 2, 3), - ] - # arrays with last dimension aligned self.aligned_arrs = [ rand_int(2, 3), @@ -135,7 +112,6 @@ class Cases(): test_case = Cases() -context.set_context(mode=context.GRAPH_MODE, device_target='CPU') def mnp_add(x1, x2): @@ -170,6 +146,14 @@ def onp_divide(x1, x2): return onp.divide(x1, x2) +def mnp_true_divide(x1, x2): + return mnp.true_divide(x1, x2) + + +def onp_true_divide(x1, x2): + return onp.true_divide(x1, x2) + + def mnp_power(x1, x2): return mnp.power(x1, x2) @@ -178,68 +162,123 @@ def onp_power(x1, x2): return onp.power(x1, x2) -def mnp_inner(a, b): - return mnp.inner(a, b) - - -def onp_inner(a, b): - return onp.inner(a, b) +def mnp_float_power(x1, x2): + return mnp.float_power(x1, x2) -def mnp_dot(a, b): - return mnp.dot(a, b) +def onp_float_power(x1, x2): + return onp.float_power(x1, x2) -def onp_dot(a, b): - return onp.dot(a, b) +def mnp_minimum(a, b): + return mnp.minimum(a, b) -def mnp_outer(a, b): - return mnp.outer(a, b) +def onp_minimum(a, b): + return onp.minimum(a, b) -def onp_outer(a, b): - return onp.outer(a, b) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_add(): + run_binop_test(mnp_add, onp_add, test_case) -def mnp_add_kwargs(x, y, where=None, out=None): - return mnp.add(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_subtract(): + run_binop_test(mnp_subtract, onp_subtract, test_case) -def onp_add_kwargs(x, y, where=None, out=None): - return onp.add(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_multiply(): + run_binop_test(mnp_mutiply, onp_multiply, test_case) -def mnp_subtract_kwargs(x, y, where=None, out=None): - return mnp.subtract(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_divide(): + run_binop_test(mnp_divide, onp_divide, test_case) -def 
onp_subtract_kwargs(x, y, where=None, out=None): - return onp.subtract(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_true_divide(): + run_binop_test(mnp_true_divide, onp_true_divide, test_case) -def mnp_multiply_kwargs(x, y, where=None, out=None): - return mnp.multiply(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_power(): + run_binop_test(mnp_power, onp_power, test_case) -def onp_multiply_kwargs(x, y, where=None, out=None): - return onp.multiply(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_float_power(): + run_binop_test(mnp_float_power, onp_float_power, test_case) -def mnp_divide_kwargs(x, y, where=None, out=None): - return mnp.divide(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_minimum(): + run_binop_test(mnp_minimum, onp_minimum, test_case) -def onp_divide_kwargs(x, y, where=None, out=None): - return onp.divide(x, y, where=where, out=out) +def mnp_add_kwargs(x, y, where=None, out=None): + return mnp.add(x, y, where=where, out=out) -def mnp_power_kwargs(x, y, where=None, out=None): - return mnp.power(x, y, where=where, out=out) +def onp_add_kwargs(x, y, where=None, out=None): + return onp.add(x, y, where=where, out=out) -def onp_power_kwargs(x, y, where=None, out=None): - return onp.power(x, y, where=where, out=out) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_add_kwargs(): + for where in test_case.bool_broadcastables[:2]: + for x in test_case.broadcastables[:2]: + for y in test_case.broadcastables[:2]: + shape_out = onp.broadcast(where, x, y).shape + out = rand_int(*shape_out) + match_res(mnp_add_kwargs, onp_add_kwargs, x, y, where, out) def mnp_tensordot(x, y): @@ -266,41 +305,42 @@ def onp_tensordot(x, y): return a, b, c, d, e, f, g, h -def run_binop_test(mnp_fn, onp_fn): - for arr in test_case.arrs: - match_res(mnp_fn, onp_fn, arr, arr) - - for scalar in test_case.scalars: - match_res(mnp_fn, onp_fn, arr, scalar) - match_res(mnp_fn, onp_fn, scalar, arr) - - for scalar1 in test_case.scalars: - for scalar2 in test_case.scalars: - match_res(mnp_fn, onp_fn, scalar1, scalar2) - - for expanded_arr1 in test_case.expanded_arrs: - for expanded_arr2 in test_case.expanded_arrs: - match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2) - - for broadcastable1 in test_case.broadcastables: - for broadcastable2 in test_case.broadcastables: - match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2) - - -def run_multi_test(mnp_fn, onp_fn, arrs): - mnp_arrs = map(mnp.asarray, arrs) - for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)): - match_array(actual.asnumpy(), expected) - - @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training 
@pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_add(): - run_binop_test(mnp_add, onp_add) +def test_tensordot(): + x = rand_int(4, 2, 7, 7) + y = rand_int(7, 7, 6) + run_multi_test(mnp_tensordot, onp_tensordot, (x, y)) + + +def mnp_std(x): + a = mnp.std(x) + b = mnp.std(x, axis=None) + c = mnp.std(x, axis=0) + d = mnp.std(x, axis=1) + e = mnp.std(x, axis=(-1, 1)) + f = mnp.std(x, axis=(0, 1, 2)) + g = mnp.std(x, axis=None, ddof=1, keepdims=True) + h = mnp.std(x, axis=0, ddof=1, keepdims=True) + i = mnp.std(x, axis=(2), ddof=1, keepdims=True) + return a, b, c, d, e, f, g, h, i + + +def onp_std(x): + a = onp.std(x) + b = onp.std(x, axis=None) + c = onp.std(x, axis=0) + d = onp.std(x, axis=1) + e = onp.std(x, axis=(-1, 1)) + f = onp.std(x, axis=(0, 1, 2)) + g = onp.std(x, axis=None, ddof=1, keepdims=True) + h = onp.std(x, axis=0, ddof=1, keepdims=True) + i = onp.std(x, axis=(2), ddof=1, keepdims=True) + return a, b, c, d, e, f, g, h, i @pytest.mark.level1 @@ -309,8 +349,29 @@ def test_add(): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_subtract(): - run_binop_test(mnp_subtract, onp_subtract) +def test_std(): + arr1 = rand_int(2, 3, 4, 5) + arr2 = rand_int(4, 5, 4, 3, 3) + run_single_test(mnp_std, onp_std, arr1, error=1e-5) + run_single_test(mnp_std, onp_std, arr2, error=1e-5) + + +def mnp_var(x): + a = mnp.std(x) + b = mnp.std(x, axis=0) + c = mnp.std(x, axis=(0)) + d = mnp.std(x, axis=(0, 1, 2)) + e = mnp.std(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + return a, b, c, d, e + + +def onp_var(x): + a = onp.std(x) + b = onp.std(x, axis=0) + c = onp.std(x, axis=(0)) + d = onp.std(x, axis=(0, 1, 2)) + e = onp.std(x, axis=(-1, 1, 2), ddof=1, keepdims=True) + return a, b, c, d, e @pytest.mark.level1 @@ -319,8 +380,37 @@ def test_subtract(): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_multiply(): - run_binop_test(mnp_mutiply, onp_multiply) +def test_var(): + arr1 = rand_int(2, 3, 4, 5) + arr2 = rand_int(4, 5, 4, 3, 3) + run_single_test(mnp_var, onp_var, arr1, error=1e-5) + run_single_test(mnp_var, onp_var, arr2, error=1e-5) + + +def mnp_average(x): + a = mnp.average(x) + b = mnp.average(x, axis=None) + c = mnp.average(x, axis=0) + d = mnp.average(x, axis=1) + e = mnp.average(x, axis=(-2, 1)) + f = mnp.average(x, axis=(0, 1, 2, 3)) + g = mnp.average(x, axis=None, weights=x) + h = mnp.average(x, axis=0, weights=x) + i = mnp.average(x, axis=(1, 2, 3), weights=x) + return a, b, c, d, e, f, g, h, i + + +def onp_average(x): + a = onp.average(x) + b = onp.average(x, axis=None) + c = onp.average(x, axis=0) + d = onp.average(x, axis=1) + e = onp.average(x, axis=(-2, 1)) + f = onp.average(x, axis=(0, 1, 2, 3)) + g = onp.average(x, axis=None, weights=x) + h = onp.average(x, axis=0, weights=x) + i = onp.average(x, axis=(1, 2, 3), weights=x) + return a, b, c, d, e, f, g, h, i @pytest.mark.level1 @@ -329,8 +419,31 @@ def test_multiply(): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_divide(): - run_binop_test(mnp_divide, onp_divide) +def test_average(): + arr1 = rand_int(2, 3, 4, 5) + arr2 = rand_int(4, 5, 1, 3, 1) + run_single_test(mnp_average, onp_average, arr1) + run_single_test(mnp_average, onp_average, arr2) + + +def mnp_count_nonzero(x): + a = mnp.count_nonzero(x) + b = mnp.count_nonzero(x, axis=None) + c = mnp.count_nonzero(x, axis=0) + 
d = mnp.count_nonzero(x, axis=1) + e = mnp.count_nonzero(x, axis=(-2, 1)) + f = mnp.count_nonzero(x, axis=(0, 1, 2, 3)) + return a, b, c, d, e, f + + +def onp_count_nonzero(x): + a = onp.count_nonzero(x) + b = onp.count_nonzero(x, axis=None) + c = onp.count_nonzero(x, axis=0) + d = onp.count_nonzero(x, axis=1) + e = onp.count_nonzero(x, axis=(-2, 1)) + f = onp.count_nonzero(x, axis=(0, 1, 2, 3)) + return a, b, c, d, e, f @pytest.mark.level1 @@ -339,8 +452,20 @@ def test_divide(): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_power(): - run_binop_test(mnp_power, onp_power) +def test_count_nonzero(): + # minus 5 to make some values below zero + arr1 = rand_int(2, 3, 4, 5) - 5 + arr2 = rand_int(4, 5, 4, 3, 3) - 5 + run_single_test(mnp_count_nonzero, onp_count_nonzero, arr1) + run_single_test(mnp_count_nonzero, onp_count_nonzero, arr2) + + +def mnp_inner(a, b): + return mnp.inner(a, b) + + +def onp_inner(a, b): + return onp.inner(a, b) @pytest.mark.level1 @@ -360,6 +485,14 @@ def test_inner(): scalar1, scalar2) +def mnp_dot(a, b): + return mnp.dot(a, b) + + +def onp_dot(a, b): + return onp.dot(a, b) + + @pytest.mark.level1 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -388,29 +521,12 @@ def test_dot(): test_case.core_broadcastables[2*i], test_case.core_broadcastables[2*i + 1]) -@pytest.mark.level1 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_x86_gpu_training -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_outer(): - run_binop_test(mnp_outer, onp_outer) +def mnp_outer(a, b): + return mnp.outer(a, b) -@pytest.mark.level1 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_x86_gpu_training -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -def test_add_kwargs(): - for where in test_case.bool_broadcastables[:2]: - for x in test_case.broadcastables[:2]: - for y in test_case.broadcastables[:2]: - shape_out = onp.broadcast(where, x, y).shape - out = rand_int(*shape_out) - match_res(mnp_add_kwargs, onp_add_kwargs, x, y, where, out) +def onp_outer(a, b): + return onp.outer(a, b) @pytest.mark.level1 @@ -419,10 +535,8 @@ def test_add_kwargs(): @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard -def test_tensordot(): - x = rand_int(4, 2, 7, 7) - y = rand_int(7, 7, 6) - run_multi_test(mnp_tensordot, onp_tensordot, (x, y)) +def test_outer(): + run_binop_test(mnp_outer, onp_outer, test_case) @pytest.mark.level1 @@ -476,6 +590,47 @@ def test_absolute(): onp.absolute(a.asnumpy(), out=out, where=where)) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_deg2rad_rad2deg(): + arrs = [rand_int(2, 3), rand_int(1, 2, 4), rand_int(2, 4)] + for arr in arrs: + match_res(mnp.deg2rad, onp.deg2rad, arr) + match_res(mnp.rad2deg, onp.rad2deg, arr) + + +def mnp_ptp(x): + a = mnp.ptp(x) + b = mnp.ptp(x, keepdims=True) + c = mnp.ptp(x, axis=(0, 1)) + d = mnp.ptp(x, axis=-1) + return a, b, c, d + + +def onp_ptp(x): + a = onp.ptp(x) + b = onp.ptp(x, keepdims=True) + c = onp.ptp(x, axis=(0, 1)) + d = onp.ptp(x, axis=-1) + return a, b, c, d + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training 
+@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_ptp(): + arrs = [rand_int(2, 3), rand_int(1, 2, 4), rand_int(2, 4)] + for arr in arrs: + match_res(mnp_ptp, onp_ptp, arr) + + def mnp_add_dtype(x1, x2, out, where): a = mnp.add(x1, x2, dtype=mnp.float16) b = mnp.add(x1, x2, out=out, dtype=mnp.float16) @@ -511,26 +666,478 @@ def test_add_dtype(): assert actual.asnumpy().dtype == expected.dtype -# check if the output from mnp function and onp function applied on the arrays are matched +def mnp_matmul(x1, x2): + return mnp.matmul(x1, x2) -def match_res(mnp_fn, onp_fn, *arrs): - mnp_arrs = map(partial(mnp.asarray, dtype='float32'), arrs) - mnp_res = mnp_fn(*mnp_arrs) - onp_res = onp_fn(*arrs) - if isinstance(mnp_res, (tuple, list)): - for actual, expected in zip(mnp_res, onp_res): - match_array(actual.asnumpy(), expected) - else: - match_array(mnp_res.asnumpy(), onp_res) +def onp_matmul(x1, x2): + return onp.matmul(x1, x2) -def match_array(actual, expected, error=5): - if error > 0: - onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(), - decimal=error) - else: - onp.testing.assert_equal(actual.tolist(), expected.tolist()) +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_matmul(): + for scalar1 in test_case.scalars[1:]: + for scalar2 in test_case.scalars[1:]: + match_res(mnp_matmul, onp_matmul, + scalar1, scalar2) + for i in range(8): + match_res(mnp_matmul, onp_matmul, + test_case.core_broadcastables[2*i], + test_case.core_broadcastables[2*i + 1]) + + +def mnp_square(x): + return mnp.square(x) + + +def onp_square(x): + return onp.square(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_square(): + run_unary_test(mnp_square, onp_square, test_case) + + +def mnp_sqrt(x): + return mnp.sqrt(x) + + +def onp_sqrt(x): + return onp.sqrt(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_sqrt(): + run_unary_test(mnp_sqrt, onp_sqrt, test_case) + + +def mnp_reciprocal(x): + return mnp.reciprocal(x) + + +def onp_reciprocal(x): + return onp.reciprocal(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_reciprocal(): + run_unary_test(mnp_reciprocal, onp_reciprocal, test_case) + + +def mnp_log(x): + return mnp.log(x) + + +def onp_log(x): + return onp.log(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_log(): + run_unary_test(mnp.log, onp.log, test_case, error=1e-5) + + +def mnp_maximum(x1, x2): + return mnp.maximum(x1, x2) + + +def onp_maximum(x1, x2): + return onp.maximum(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_maximum(): + 
run_binop_test(mnp_maximum, onp_maximum, test_case) + + +def mnp_clip(x): + a = mnp.clip(x, mnp.asarray(10.0), mnp.asarray([2,])) + b = mnp.clip(x, 0, 1) + c = mnp.clip(x, mnp.asarray(0), mnp.asarray(10), dtype=mnp.float64) + return a, b, c + + +def onp_clip(x): + a = onp.clip(x, onp.asarray(10.0), onp.asarray([2,])) + b = onp.clip(x, 0, 1) + c = onp.clip(x, onp.asarray(0), onp.asarray(10), dtype=onp.float64) + return a, b, c + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_clip(): + run_unary_test(mnp_clip, onp_clip, test_case) + + +def mnp_amax(x, mask): + a = mnp.amax(x) + b = mnp.amax(x, axis=-3) + c = mnp.amax(x, keepdims=True) + d = mnp.amax(x, initial=3) + e = mnp.amax(x, axis=(0, 1), keepdims=True) + f = mnp.amax(x, initial=4, where=mask) + g = mnp.amax(x, initial=5, where=mask, keepdims=True) + h = mnp.amax(x, axis=(1, 2, 3), initial=6, where=mask) + return a, b, c, d, e, f, g, h + + +def onp_amax(x, mask): + a = onp.amax(x) + b = onp.amax(x, axis=-3) + c = onp.amax(x, keepdims=True) + d = onp.amax(x, initial=3) + e = onp.amax(x, axis=(0, 1), keepdims=True) + f = onp.amax(x, initial=4, where=mask) + g = onp.amax(x, initial=5, where=mask, keepdims=True) + h = onp.amax(x, axis=(1, 2, 3), initial=6, where=mask) + return a, b, c, d, e, f, g, h + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_amax(): + a = rand_int(2, 3, 4, 5).astype('float32') + mask = rand_bool(2, 3, 4, 5) + run_multi_test(mnp_amax, onp_amax, (a, mask)) + + +def mnp_amin(x, mask): + a = mnp.amin(x) + b = mnp.amin(x, axis=-3) + c = mnp.amin(x, keepdims=True) + d = mnp.amin(x, initial=-1) + e = mnp.amin(x, axis=(0, 1), keepdims=True) + f = mnp.amin(x, initial=-2, where=mask) + g = mnp.amin(x, initial=-3, where=mask, keepdims=True) + h = mnp.amin(x, axis=(1, 2, 3), initial=-4, where=mask) + return a, b, c, d, e, f, g, h + + +def onp_amin(x, mask): + a = onp.amin(x) + b = onp.amin(x, axis=-3) + c = onp.amin(x, keepdims=True) + d = onp.amin(x, initial=-1) + e = onp.amin(x, axis=(0, 1), keepdims=True) + f = onp.amin(x, initial=-2, where=mask) + g = onp.amin(x, initial=-3, where=mask, keepdims=True) + h = onp.amin(x, axis=(1, 2, 3), initial=-4, where=mask) + return a, b, c, d, e, f, g, h + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_amin(): + a = rand_int(2, 3, 4, 5).astype('float32') + mask = rand_bool(2, 3, 4, 5) + run_multi_test(mnp_amin, onp_amin, (a, mask)) + + +def mnp_hypot(x1, x2): + return mnp.hypot(x1, x2) + + +def onp_hypot(x1, x2): + return onp.hypot(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_hypot(): + run_binop_test(mnp_hypot, onp_hypot, test_case) + + +def mnp_heaviside(x1, x2): + return mnp.heaviside(x1, x2) + + +def onp_heaviside(x1, x2): + return onp.heaviside(x1, x2) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training 
+@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_heaviside(): + broadcastables = test_case.broadcastables + for b1 in broadcastables: + for b2 in broadcastables: + b = onp.subtract(b1, b2) + match_res(mnp_heaviside, onp_heaviside, b, b1) + match_res(mnp_heaviside, onp_heaviside, b, b2) + + +def mnp_floor(x): + return mnp.floor(x) + + +def onp_floor(x): + return onp.floor(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_floor(): + run_unary_test(mnp_floor, onp_floor, test_case) + x = rand_int(2, 3) * onp.random.rand(2, 3) + match_res(mnp_floor, onp_floor, x) + match_res(mnp_floor, onp_floor, -x) + + +def mnp_floor_divide(x, y): + return mnp.floor_divide(x, y) + + +def onp_floor_divde(x, y): + return onp.floor_divide(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_floor_divide(): + run_binop_test(mnp_floor_divide, onp_floor_divde, test_case) + + +def mnp_remainder(x, y): + return mnp.remainder(x, y) + + +def onp_remainder(x, y): + return onp.remainder(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_remainder(): + run_binop_test(mnp_remainder, onp_remainder, test_case) + + +def mnp_mod(x, y): + return mnp.mod(x, y) + + +def onp_mod(x, y): + return onp.mod(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_mod(): + run_binop_test(mnp_mod, onp_mod, test_case) + + +def mnp_fmod(x, y): + return mnp.fmod(x, y) + + +def onp_fmod(x, y): + return onp.fmod(x, y) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_fmod(): + run_binop_test(mnp_fmod, onp_fmod, test_case) + + +def mnp_fix(x): + return mnp.fix(x) + + +def onp_fix(x): + return onp.fix(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_fix(): + x = rand_int(2, 3) + y = rand_int(2, 3) + floats = onp.divide(onp.subtract(x, y), y) + match_res(mnp_fix, onp_fix, floats) + run_binop_test(mnp_fmod, onp_fmod, test_case) + + +def mnp_trunc(x): + return mnp.trunc(x) + + +def onp_trunc(x): + return onp.trunc(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_trunc(): + x = rand_int(2, 3) + y = rand_int(2, 3) + floats = onp.divide(onp.subtract(x, y), y) + match_res(mnp_trunc, onp_trunc, floats) + + +def mnp_exp(x): + return mnp.exp(x) + + +def onp_exp(x): + return onp.exp(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu 
+@pytest.mark.env_onecard +def test_exp(): + run_unary_test(mnp_exp, onp_exp, test_case, error=5) + + +def mnp_expm1(x): + return mnp.expm1(x) + + +def onp_expm1(x): + return onp.expm1(x) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_expm1(): + run_unary_test(mnp_expm1, onp_expm1, test_case, error=5) + + +def mnp_positive(x, out, where): + return mnp.positive(x, out=out, where=where) + + +def onp_positive(x, out, where): + return onp.positive(x, out=out, where=where) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_positive(): + arr = onp.arange(-6, 6).reshape((2, 2, 3)).astype('float32') + out_lst = [onp.ones((2, 2, 3)).astype('float32'), onp.ones((5, 2, 2, 3)).astype('float32')] + where_lst = [onp.full((2, 2, 3), [True, False, True]), onp.full((2, 3), False)] + for out in out_lst: + for where in where_lst: + onp_pos = onp_positive(arr, out=out, where=where) + mnp_pos = mnp_positive(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where)) + match_array(mnp_pos.asnumpy(), onp_pos) + + +def mnp_negative(x, out, where): + return mnp.negative(x, out=out, where=where) + + +def onp_negative(x, out, where): + return onp.negative(x, out=out, where=where) + + +@pytest.mark.level1 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_negative(): + arr = onp.arange(-6, 6).reshape((2, 2, 3)).astype('float32') + out_lst = [onp.ones((2, 2, 3)).astype('float32'), onp.ones((5, 2, 2, 3)).astype('float32')] + where_lst = [onp.full((2, 2, 3), [True, False, True]), onp.full((2, 3), False)] + for out in out_lst: + for where in where_lst: + onp_neg = onp_negative(arr, out=out, where=where) + mnp_neg = mnp_negative(mnp.asarray(arr), mnp.asarray(out), mnp.asarray(where)) + match_array(mnp_neg.asnumpy(), onp_neg) @pytest.mark.level1 diff --git a/tests/st/numpy_native/utils.py b/tests/st/numpy_native/utils.py new file mode 100644 index 0000000000..6a97e73a04 --- /dev/null +++ b/tests/st/numpy_native/utils.py @@ -0,0 +1,165 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""utility functions for mindspore.numpy st tests"""
+import functools
+import numpy as onp
+import mindspore.numpy as mnp
+
+
+def match_array(actual, expected, error=0):
+    """Assert that actual and expected agree to `error` decimal places (exactly if error is 0)."""
+    if isinstance(actual, int):
+        actual = onp.asarray(actual)
+
+    if isinstance(expected, int):
+        expected = onp.asarray(expected)
+
+    if error > 0:
+        onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
+                                        decimal=error)
+    else:
+        onp.testing.assert_equal(actual.tolist(), expected.tolist())
+
+
+def check_all_results(onp_results, mnp_results, error=0):
+    """Check all results from numpy and mindspore.numpy"""
+    for i, _ in enumerate(onp_results):
+        match_array(onp_results[i], mnp_results[i].asnumpy(), error=error)
+
+
+def check_all_unique_results(onp_results, mnp_results):
+    """
+    Check all results from numpy and mindspore.numpy.
+
+    Args:
+        onp_results (Union[tuple of numpy.arrays, numpy.array])
+        mnp_results (Union[tuple of Tensors, Tensor])
+    """
+    for i, _ in enumerate(onp_results):
+        if isinstance(onp_results[i], tuple):
+            for j in range(len(onp_results[i])):
+                match_array(onp_results[i][j],
+                            mnp_results[i][j].asnumpy(), error=7)
+        else:
+            match_array(onp_results[i], mnp_results[i].asnumpy(), error=7)
+
+
+def run_non_kw_test(mnp_fn, onp_fn, test_case):
+    """Run tests on functions with non-keyword arguments"""
+    for i in range(len(test_case.arrs)):
+        arrs = test_case.arrs[:i]
+        match_res(mnp_fn, onp_fn, *arrs)
+
+    for i in range(len(test_case.scalars)):
+        arrs = test_case.scalars[:i]
+        match_res(mnp_fn, onp_fn, *arrs)
+
+    for i in range(len(test_case.expanded_arrs)):
+        arrs = test_case.expanded_arrs[:i]
+        match_res(mnp_fn, onp_fn, *arrs)
+
+    for i in range(len(test_case.nested_arrs)):
+        arrs = test_case.nested_arrs[:i]
+        match_res(mnp_fn, onp_fn, *arrs)
+
+
+def rand_int(*shape):
+    """Return a random integer array (cast to float32) with the given shape."""
+    res = onp.random.randint(low=1, high=5, size=shape)
+    if isinstance(res, onp.ndarray):
+        return res.astype(onp.float32)
+    return float(res)
+
+
+# Return a random boolean array with the given shape.
+def rand_bool(*shape):
+    return onp.random.rand(*shape) > 0.5
+
+
+def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
+    """Checks results from applying mnp_fn and onp_fn on arrs respectively"""
+    mnp_arrs = map(functools.partial(mnp.asarray, dtype='float32'), arrs)
+    error = kwargs.get('error', 0)
+    kwargs.pop('error', None)
+    mnp_res = mnp_fn(*mnp_arrs, **kwargs)
+    onp_res = onp_fn(*arrs, **kwargs)
+    match_all_arrays(mnp_res, onp_res, error=error)
+
+
+def match_all_arrays(mnp_res, onp_res, error=0):
+    if isinstance(mnp_res, (tuple, list)):
+        assert len(mnp_res) == len(onp_res)
+        for actual, expected in zip(mnp_res, onp_res):
+            match_array(actual.asnumpy(), expected, error)
+    else:
+        match_array(mnp_res.asnumpy(), onp_res, error)
+
+
+def match_meta(actual, expected):
+    # float64 and int64 are not supported, and the default types for
+    # float and int are float32 and int32, respectively
+    if expected.dtype == onp.float64:
+        expected = expected.astype(onp.float32)
+    elif expected.dtype == onp.int64:
+        expected = expected.astype(onp.int32)
+    assert actual.shape == expected.shape
+    assert actual.dtype == expected.dtype
+
+
+def run_binop_test(mnp_fn, onp_fn, test_case):
+    for arr in test_case.arrs:
+        match_res(mnp_fn, onp_fn, arr, arr)
+
+    for scalar in test_case.scalars:
+        match_res(mnp_fn, onp_fn, arr, scalar)
+        match_res(mnp_fn, onp_fn, scalar, arr)
+
+    for scalar1 in test_case.scalars:
+        for scalar2 in test_case.scalars:
+            match_res(mnp_fn, onp_fn, scalar1, scalar2)
+
+    for expanded_arr1 in test_case.expanded_arrs:
+        for expanded_arr2 in test_case.expanded_arrs:
+            match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2)
+
+    for broadcastable1 in test_case.broadcastables:
+        for broadcastable2 in test_case.broadcastables:
+            match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2)
+
+
+def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
+    for arr in test_case.arrs:
+        match_res(mnp_fn, onp_fn, arr, error=error)
+
+    for arr in test_case.scalars:
+        match_res(mnp_fn, onp_fn, arr, error=error)
+
+    for arr in test_case.expanded_arrs:
+        match_res(mnp_fn, onp_fn, arr, error=error)
+
+
+def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
+    mnp_arrs = map(mnp.asarray, arrs)
+    for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
+        match_array(actual.asnumpy(), expected, error)
+
+
+def run_single_test(mnp_fn, onp_fn, arr, error=0):
+    mnp_arr = mnp.asarray(arr)
+    for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
+        # a result may itself be a tuple of arrays; compare element-wise in that case
+        if isinstance(expected, tuple):
+            for actual_arr, expected_arr in zip(actual, expected):
+                match_array(actual_arr.asnumpy(), expected_arr, error)
+        else:
+            match_array(actual.asnumpy(), expected, error)