From eab5ecf21a21b56a9d9b7b28022205c71c11fb38 Mon Sep 17 00:00:00 2001 From: yanglf1121 Date: Mon, 18 Jan 2021 09:26:02 +0800 Subject: [PATCH] Fix bugs in np.atleastxd, np.xstack and format api docstrings --- mindspore/_checkparam.py | 2 +- mindspore/_extends/parse/standard_method.py | 2 +- mindspore/common/tensor.py | 2 +- mindspore/numpy/__init__.py | 10 +- mindspore/numpy/array_creations.py | 18 +-- mindspore/numpy/array_ops.py | 148 +++++++----------- mindspore/numpy/dtypes.py | 5 +- mindspore/numpy/math_ops.py | 37 ++--- mindspore/numpy/utils.py | 4 +- mindspore/numpy/utils_const.py | 67 +------- tests/st/numpy_native/__init__.py | 2 +- tests/st/numpy_native/test_array_creations.py | 4 +- tests/st/numpy_native/test_array_ops.py | 5 +- tests/st/numpy_native/test_math_ops.py | 2 +- 14 files changed, 100 insertions(+), 208 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index f247be9c88..028084ccca 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindspore/_extends/parse/standard_method.py b/mindspore/_extends/parse/standard_method.py index 03c940f1b9..43d3d20f6c 100644 --- a/mindspore/_extends/parse/standard_method.py +++ b/mindspore/_extends/parse/standard_method.py @@ -1,6 +1,6 @@ # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). # -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 5e8cbbc857..0610db73d2 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/mindspore/numpy/__init__.py b/mindspore/numpy/__init__.py index add6582f53..726ef318ac 100644 --- a/mindspore/numpy/__init__.py +++ b/mindspore/numpy/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,10 +19,10 @@ Examples: >>> import mindspore.numpy as np Note: - - array_ops.py define all the array generation and operation interfaces. - - math_ops.py define all the math operations on tensors. - - dtypes.py define all the mindspore.numpy dtypes (mainly redirected from mindspore) - - random/ defines all the random operations. + - array_ops.py defines all the array operation interfaces. + - array_creations.py defines all the array generation interfaces. + - math_ops.py defines all the math operations on tensors. + - dtypes.py defines all the mindspore.numpy dtypes (mainly redirected from mindspore) """ from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape, diff --git a/mindspore/numpy/array_creations.py b/mindspore/numpy/array_creations.py index 0fe486dbd4..3cfb503e34 100644 --- a/mindspore/numpy/array_creations.py +++ b/mindspore/numpy/array_creations.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -153,7 +153,7 @@ def asarray(a, dtype=None): elif a.dtype is onp.dtype('float'): dtype = mstype.float32 elif a.dtype is onp.dtype('object'): - raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") + raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") a = Tensor.from_numpy(a) # If a is already a tensor and we don't need to cast dtype, return a @@ -208,7 +208,7 @@ def asfarray(a, dtype=mstype.float32): a = _deep_tensor_to_nparray(a) a = onp.asarray(a) if a.dtype is onp.dtype('object'): - raise TypeError(f"For Tensor convertion, the input_data is {a} that contains unsupported element.") + raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") if isinstance(a, onp.ndarray): a = Tensor.from_numpy(a) @@ -952,7 +952,7 @@ def tril(m, k=0): Returns a copy of an array with elements above the k-th diagonal zeroed. Args: - m(array_like): The shape and data-type of a define these same + m(array_like): The shape and data-type of m define these same attributes of the returned array. k(int, optional): Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above. @@ -987,16 +987,16 @@ def triu(m, k=0): """ Returns an upper triangle of an array. - Returns a copy of an array with elements above the k-th diagonal zeroed. + Returns a copy of an array with elements below the k-th diagonal zeroed. Args: - m(array_like): The shape and data-type of a define these same + m(array_like): The shape and data-type of m define these same attributes of the returned array. - k(int, optional): Diagonal above which to zero elements. k = 0 (the default) + k(int, optional): Diagonal below which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above. Returns: - triu(Tensor): Lower triangle of m, of same shape and data-type as m. + triu(Tensor): Upper triangle of m, of same shape and data-type as m. Raises: TypeError: If input arguments have types not specified above. @@ -1175,7 +1175,7 @@ def trace(a, offset=0, axis1=0, axis2=1): >>> print(output) [6 8] >>> a = np.arange(24).reshape((2,2,2,3)) - >>> output = np.trace.shape + >>> output = np.trace(a).shape >>> print(output) (2, 3) """ diff --git a/mindspore/numpy/array_ops.py b/mindspore/numpy/array_ops.py index b26b92a133..5374493145 100644 --- a/mindspore/numpy/array_ops.py +++ b/mindspore/numpy/array_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,11 @@ from ..ops import functional as F from ..ops.primitive import constexpr from ..nn import Cell -from .utils import _covert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ +from .utils import _convert_list_tensor_to_tuple_tensor, _expand, _broadcast_to, \ _is_empty from .utils_const import _check_is_int, _check_axes_range, _check_start_normalize, \ _check_is_tensor, _check_is_tuple, _check_is_list, _raise_type_error, _raise_value_error, \ - _infer_out_shape, _get_index_for_unique, _get_counts_for_unique, _empty, _promote, \ - _min, _check_same_type, _check_input_tensor + _infer_out_shape, _empty, _promote, _check_same_type, _check_input_tensor # According to official numpy reference, the dimension of a numpy array must be less # than 32 @@ -336,7 +335,7 @@ def ravel(x): Flattened tensor, has the same data type as the original tensor x. Raises: - If x is not tensor. + TypeError: If x is not tensor. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -450,7 +449,7 @@ def concatenate(arrays, axis=0): return P.Concat(axis)(flattened_arrays) # convert a list of tensor to a tuple of tensor - arrays = _covert_list_tensor_to_tuple_tensor(arrays) + arrays = _convert_list_tensor_to_tuple_tensor(arrays) arr_shape = F.shape(arrays[0]) _check_axes_range((axis,), len(arr_shape)) @@ -503,12 +502,11 @@ def column_stack(tup): trans_tup = () for tensor in tup: - shape = F.shape(tensor) - if F.tuple_len(shape) == 1: - reshape_tensor = F.reshape(tensor, shape+(1,)) - trans_tup += (reshape_tensor,) - else: - trans_tup += (tensor,) + if tensor.ndim < 1: + tensor = F.expand_dims(tensor, 0) + if tensor.ndim == 1: + tensor = F.expand_dims(tensor, 1) + trans_tup += (tensor,) return P.Concat(axis=1)(trans_tup) @@ -552,12 +550,9 @@ def vstack(tup): trans_tup = () for tensor in tup: - shape = F.shape(tensor) - if F.tuple_len(shape) == 1: - reshape_tensor = F.reshape(tensor, (1,)+shape) - trans_tup += (reshape_tensor,) - else: - trans_tup += (tensor,) + if tensor.ndim <= 1: + tensor = _expand(tensor, 2, 0) + trans_tup += (tensor,) return P.Concat(axis=0)(trans_tup) @@ -600,13 +595,12 @@ def hstack(tup): _raise_value_error("Need at least one tensor to concatenate.") tuple_of_tensor = () - if _check_is_list(tup): - for tensor in tup: - tuple_of_tensor += (tensor,) - else: - tuple_of_tensor = tup + for tensor in tup: + if tensor.ndim < 1: + tensor = F.expand_dims(tensor, 0) + tuple_of_tensor += (tensor,) - if F.tuple_len(F.shape(tup[0])) == 1: + if tuple_of_tensor[0].ndim <= 1: return P.Concat(axis=0)(tuple_of_tensor) return P.Concat(axis=1)(tuple_of_tensor) @@ -652,15 +646,11 @@ def dstack(tup): trans_tup = () for tensor in tup: - shape = F.shape(tensor) - if F.tuple_len(shape) == 1: - reshape_tensor = F.reshape(tensor, (1,)+shape+(1,)) - trans_tup += (reshape_tensor,) - elif F.tuple_len(shape) == 2: - reshape_tensor = F.reshape(tensor, shape+(1,)) - trans_tup += (reshape_tensor,) - else: - trans_tup += (tensor,) + if tensor.ndim <= 1: + tensor = _expand(tensor, 2, 0) + if tensor.ndim == 2: + tensor = F.expand_dims(tensor, 2) + trans_tup += (tensor,) return P.Concat(axis=2)(trans_tup) @@ -670,10 +660,6 @@ def where(condition, x=None, y=None): Note: As nonzero is not supported, neither x or y can be None. - On CPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. - On GPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. Args: condition (Tensor): where True, yield x, otherwise yield y. @@ -724,6 +710,9 @@ def where(condition, x=None, y=None): shape_out = _infer_out_shape(F.shape(condition), F.shape(x), F.shape(y)) ndim_out = len(shape_out) + if not _check_same_type(F.dtype(condition), mstype.float32): + # tiling with bool is not supported on GPU + condition = F.cast(condition, mstype.float32) condition = _expand(condition, ndim_out) x = _expand(x, ndim_out) y = _expand(y, ndim_out) @@ -739,24 +728,16 @@ def where(condition, x=None, y=None): return res -def _expand_atleast(arr, ndim): - """Expands arr to at least ndim.""" - arr = _expand(arr, _min(ndim, 2)) - if ndim > 2: - arr = _expand(arr, ndim, axis=-1) - return arr - - def _atleast_xd(ndim, arys): """Returns arys with at least ndim.""" for arr in arys: _check_input_tensor(F.typeof(arr)) - - if F.tuple_len(arys) == 1: - return _expand_atleast(*arys, ndim) res = [] - for arr in res: - res.append(_expand_atleast(arr, ndim)) + for arr in arys: + arr = _expand(arr, ndim) + res.append(arr) + if len(res) == 1: + return res[0] return res @@ -770,10 +751,6 @@ def atleast_1d(*arys): Note: In graph mode, returns a tuple of tensor instead of a list of tensors. - On CPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. - On GPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. Args: arys1, arys2, … (Tensor): one or more input tensors. @@ -810,10 +787,6 @@ def atleast_2d(*arys): Note: In graph mode, returns a tuple of tensor instead of a list of tensors. - On CPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. - On GPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. Args: arys1, arys2, … (Tensor): one or more input tensors. @@ -850,10 +823,7 @@ def atleast_3d(*arys): Note: In graph mode, returns a tuple of tensor instead of a list of tensors. - On CPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. - On GPU, the supported dtypes are np.float16, np.float32, np.int16, - and np.int32. + Args: arys1, arys2, … (Tensor): one or more input tensors. @@ -882,7 +852,19 @@ def atleast_3d(*arys): value= [[[1.00000000e+000], [1.00000000e+000], [1.00000000e+000], [1.00000000e+000], [1.00000000e+000]]])) """ - return _atleast_xd(3, arys) + res = [] + for arr in arys: + ndim = F.rank(arr) + if ndim == 0: + arr = F.reshape(arr, (1, 1, 1)) + elif ndim == 1: + arr = F.reshape(arr, (1, F.size(arr), 1)) + elif ndim == 2: + arr = F.reshape(arr, F.shape(arr) + (1,)) + res.append(arr) + if len(res) == 1: + return res[0] + return res def stack(arrays, axis=0): @@ -960,31 +942,24 @@ class UniqueNet(Cell): return self.unique(x) -def unique(x, return_index=False, return_inverse=False, return_counts=False): +def unique(x, return_inverse=False): """ Finds the unique elements of a tensor. The input tensor will be flattened first when it has more than one dimension. Note: - The operation is derived from mindspore.ops.Unique. - Numpy arguments `axis` is not supported. + Numpy arguments `axis`, `return_index` and `return_counts` are not supported. + This operator must be executed in graph mode. Args: x (Tensor): The input tensor to be processed. - return_index (bool): If True, also return the indices of tensor x (along - the specified axis, if provided, or in the flattened tensor) that result - in the unique tensor. Default: False. return_inverse (bool): If True, also return the indices of the unique tensor. Default: False. - return_counts (bool): If True, also return the number of times each unique - item appears in input tensor `x`. Default: False. Returns: Tensor or tuple of Tensors. - - If all of the three bool arguments (`return_index`, `return_inverse`, `return_counts`) - are False, just return the unique tensor. - - If parts of the three bool arguments are True, the corresponding results (Tensor) - will be added in the tuple. + - If `return_inverse` is False, just return the unique tensor. + - If `return_inverse` is True, return tuple of tensors. Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` @@ -995,14 +970,12 @@ def unique(x, return_index=False, return_inverse=False, return_counts=False): Examples: >>> import mindspore.numpy as mnp >>> import numpy as onp + >>> from mindspore import context + >>> context.set_context(mode=context.GRAPH_MODE) >>> input_x = mnp.asarray(onp.array([1, 2, 2, 2, 3, 4, 5]).astype('float32')) >>> output_x = mnp.unique(input_x) >>> print(output_x) [1. 2. 3. 4. 5.] - >>> output_x = mnp.unique(input_x, return_index=True) - >>> print(output_x) - (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[5], dtype=Float32, - value= [ 0. 1. 4. 5. 6.])) >>> output_x = mnp.unique(input_x, return_inverse=True) >>> print(output_x) (Tensor(shape=[5], dtype=Float32, value= [ 1. 2. 3. 4. 5.]), Tensor(shape=[7], dtype=Int32, @@ -1013,16 +986,7 @@ def unique(x, return_index=False, return_inverse=False, return_counts=False): if F.tuple_len(F.shape(x)) > 1: x = ravel(x) uniq = UniqueNet() - unique_x, inverse_index = uniq(x) - if not return_index and not return_inverse and not return_counts: - return unique_x - res_tup = (unique_x,) - if return_index: - res_index = _get_index_for_unique(x, unique_x) - res_tup += (res_index,) - if return_inverse: - res_tup += (inverse_index,) - if return_counts: - res_counts = _get_counts_for_unique(x, unique_x) - res_tup += (res_counts,) - return res_tup + res = uniq(x) + if not return_inverse: + return res[0] + return res diff --git a/mindspore/numpy/dtypes.py b/mindspore/numpy/dtypes.py index 211cf9cd1b..acdcb45668 100644 --- a/mindspore/numpy/dtypes.py +++ b/mindspore/numpy/dtypes.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -146,6 +146,9 @@ promotion_rule = { (int32, float16): float16, (int32, float32): float32, (int32, float64): float64, + (int64, float16): float16, + (int64, float32): float32, + (int64, float64): float64, (float16, float32): float32, (float16, float64): float64, (float32, float64): float64, diff --git a/mindspore/numpy/math_ops.py b/mindspore/numpy/math_ops.py index 9d617abcc9..721ba48522 100644 --- a/mindspore/numpy/math_ops.py +++ b/mindspore/numpy/math_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -72,7 +72,8 @@ def absolute(x, out=None, where=True, dtype=None): ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> x = np.asarray([1, 2, 3, -4, -5], np.float64) + >>> import mindspore.numpy as np + >>> x = np.asarray([1, 2, 3, -4, -5], np.float32) >>> output = np.absolute(x) >>> print(output) [1. 2. 3. 4. 5.] @@ -97,10 +98,6 @@ def add(x1, x2, out=None, where=True, dtype=None): Argument out is not supported for storing the result, however it can be used in combination with argument where to set the value at indices for which where is set to False. - On GPU, the supported dtypes are np.float16, np.float32, np.int32, - and np.int64. - On CPU, the supported dtypes are np.float16, np.float32, np.float64, - np.int16, np.int32, and np.int64. Args: x1 (Tensor): input to be added. @@ -154,10 +151,6 @@ def subtract(x1, x2, out=None, where=True, dtype=None): Argument out is not supported for storing the result, however it can be used in combination with argument where to set the value at indices for which where is set to False. - On GPU, the supported dtypes are np.float16, np.float32, np.int32, - and np.int64. - On CPU, the supported dtypes are np.float16, np.float32, np.float64, - np.int16, np.int32, and np.int64. Args: x1 (Tensor): the input to be subtracted from. @@ -207,10 +200,6 @@ def multiply(x1, x2, out=None, where=True, dtype=None): Argument out is not supported for storing the result, however it can be used in combination with argument where to set the value at indices for which where is set to False. - On GPU, the supported dtypes are np.float16, np.float32, np.int32, - and np.int64. - On CPU, the supported dtypes are np.float16, np.float32, np.float64, - np.int16, np.int32, and np.int64. Args: x1 (Tensor): input tensor to be multiplied. @@ -273,8 +262,6 @@ def divide(x1, x2, out=None, where=True, dtype=None): used in combination with argument where to set the value at indices for which where is set to False. On GPU, the supported dtypes are np.float16, and np.float32. - On CPU, the supported dtypes are np.float16, np.float32, np.float64, - np.int16, np.int32, and np.int64. Args: x1 (Tensor): the divident. @@ -325,12 +312,10 @@ def power(x1, x2, out=None, where=True, dtype=None): Numpy arguments casting, order, dtype, subok, signature, and extobj are not supported. On GPU, the supported dtypes are np.float16, and np.float32. - On CPU, the supported dtypes are np.float16, np.float32, np.float64, - np.int16, np.int32, and np.int64. Args: x1 (Tensor): the bases. - x2 (Tensor): the exponenets. + x2 (Tensor): the exponents. out (Tensor or None): optional, defaults to None. where (Tensor or None): optional. For any non-default value of type other than Tensor or None, the output retains its original value. @@ -345,7 +330,7 @@ def power(x1, x2, out=None, where=True, dtype=None): Returns: Tensor or scalar, the bases in x1 raised to the exponents in x2. This - is a scalarif both x1 and x2 are scalars. + is a scalar if both x1 and x2 are scalars. Raises: TypeError: if the input is not a tensor. @@ -354,8 +339,8 @@ def power(x1, x2, out=None, where=True, dtype=None): ``Ascend`` ``GPU`` ``CPU`` Examples: - >>> x1 = np.full((3, 2), [1, 2]) - >>> x2 = np.full((3, 2), [3, 4]) + >>> x1 = np.full((3, 2), [1, 2]).astype('float32') + >>> x2 = np.full((3, 2), [3, 4]).astype('float32') >>> output = np.power(x1, x2) >>> print(output) [[ 1, 16], @@ -548,8 +533,8 @@ def dot(a, b): Examples: >>> import mindspore.numpy as np - >>> a = np.full((1, 3), 7) - >>> b = np.full((2, 3, 4), 5) + >>> a = np.full((1, 3), 7).astype('float32') + >>> b = np.full((2, 3, 4), 5).astype('float32') >>> output = np.dot(a, b) >>> print(output) [[[105, 105, 105, 105], @@ -597,8 +582,8 @@ def outer(a, b): Examples: >>> import mindspore.numpy as np - >>> a = np.full(7, 2) - >>> b = np.full(4, 3) + >>> a = np.full(7, 2).astype('float32') + >>> b = np.full(4, 3).astype('float32') >>> output = np.outer(a, b) >>> print(output) [[6, 6, 6, 6], diff --git a/mindspore/numpy/utils.py b/mindspore/numpy/utils.py index 9508c90cc2..54e78a485e 100644 --- a/mindspore/numpy/utils.py +++ b/mindspore/numpy/utils.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -77,7 +77,7 @@ def _get_device(): return context.get_context('device_target') -def _covert_list_tensor_to_tuple_tensor(list_of_tensor): +def _convert_list_tensor_to_tuple_tensor(list_of_tensor): """Convert a list of tensor to a tuple of tensor""" if isinstance(list_of_tensor, list): tuple_of_tensor = () diff --git a/mindspore/numpy/utils_const.py b/mindspore/numpy/utils_const.py index faab390073..7c5544c2ae 100644 --- a/mindspore/numpy/utils_const.py +++ b/mindspore/numpy/utils_const.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +15,7 @@ """internal graph-compatible utility functions""" from functools import partial -import numpy as onp - import mindspore.context as context -from ..common import Tensor from ..ops import functional as F from ..ops.primitive import constexpr from ..common import dtype as mstype @@ -262,65 +259,6 @@ def _empty(dtype, shape): return Tensor_(dtype, shape) -def _get_index_for_unique(input_x, unique_x): - """ - Return the indices of the first occurrences of the unique values in the original array. - - Args: - input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. - unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. - - Returns: - Tensor. The indices of the unique values in the original array. Has the same shape as `unique_x`. - """ - o_array = input_x.asnumpy() - dic = {} - for idx in range(o_array.size): - val = o_array[idx] - if val not in dic: - dic[val] = idx - - index_lst = [] - u_array = unique_x.asnumpy() - for idx in range(u_array.size): - index_lst.append(dic[u_array[idx]]) - - return Tensor(onp.array(index_lst), input_x.dtype) - - -@constexpr -def _get_counts_for_unique(input_x, unique_x): - """ - Return the number of times each of the unique values comes up in the original tensor. - - Args: - input_x (Tensor): The flattened input tensor of `mindspore.numpy.unique`. - unique_x (Tensor): The tensor contains the unique elements in `input_x`, sorted in ascending order. - - Returns: - Tensor. The number of times each of the unique values comes up in the original tensor. - """ - dic = {} - o_array = input_x.asnumpy() - for idx in range(o_array.size): - val = o_array[idx] - if val not in dic: - dic[val] = 1 - else: - dic[val] += 1 - - u_array = unique_x.asnumpy() - counts_lst = [dic[val] for val in u_array] - - return Tensor(onp.array(counts_lst), input_x.dtype) - - -@constexpr -def _get_max_value(x): - """Returns the maximum value of the input tensor `x`. """ - return int(max(x.asnumpy())) - - @constexpr def _promote(dtype1, dtype2): if dtype1 == dtype2: @@ -355,7 +293,8 @@ def _check_same_type(dtype1, dtype2): @constexpr def _check_is_float(dtype): - return dtype in mstype.float_type + """Returns whether dtype is float16 or float32.""" + return dtype in (mstype.float16, mstype.float32) @constexpr diff --git a/tests/st/numpy_native/__init__.py b/tests/st/numpy_native/__init__.py index 449438f400..bd8949921a 100644 --- a/tests/st/numpy_native/__init__.py +++ b/tests/st/numpy_native/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/st/numpy_native/test_array_creations.py b/tests/st/numpy_native/test_array_creations.py index 635b3d67f4..b60e4027be 100644 --- a/tests/st/numpy_native/test_array_creations.py +++ b/tests/st/numpy_native/test_array_creations.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -194,7 +194,7 @@ def match_all_arrays(mnp_res, onp_res, error=0): def match_meta(actual, expected): - # float64 and int64 are not supported, and the defualt type for + # float64 and int64 are not supported, and the default type for # float and int are float32 and int32, respectively if expected.dtype == onp.float64: expected = expected.astype(onp.float32) diff --git a/tests/st/numpy_native/test_array_ops.py b/tests/st/numpy_native/test_array_ops.py index 48b3634b19..c27bcdc855 100644 --- a/tests/st/numpy_native/test_array_ops.py +++ b/tests/st/numpy_native/test_array_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -168,6 +168,7 @@ def match_res(mnp_fn, onp_fn, *arrs, **kwargs): def match_all_arrays(mnp_res, onp_res, error=0): if isinstance(mnp_res, (tuple, list)): + assert len(mnp_res) == len(onp_res) for actual, expected in zip(mnp_res, onp_res): match_array(actual.asnumpy(), expected, error) else: @@ -175,7 +176,7 @@ def match_all_arrays(mnp_res, onp_res, error=0): def match_meta(actual, expected): - # float64 and int64 are not supported, and the defualt type for + # float64 and int64 are not supported, and the default type for # float and int are float32 and int32, respectively if expected.dtype == onp.float64: expected = expected.astype(onp.float32) diff --git a/tests/st/numpy_native/test_math_ops.py b/tests/st/numpy_native/test_math_ops.py index 9b430fcc92..e72ac0fcbe 100644 --- a/tests/st/numpy_native/test_math_ops.py +++ b/tests/st/numpy_native/test_math_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.