!10755 [Numpy-Native] Add new numpy-native interfaces to mindspore.numpy

From: @yanglf1121
Reviewed-by: 
Signed-off-by:
pull/10755/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit eacc8bac89

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -595,6 +595,123 @@ class Validator:
raise ValueError(f'For {prim_name}, {ori_shape} reduce on {axis} should be '
f'{tuple(exp_shape)}, but got {shape}.')
@staticmethod
def check_astype_dtype(dtype):
"""Check whether dtype is a valid input, and convert to mstype"""
all_types = mstype.__dtype__ + ["int", "float", "bool"]
if isinstance(dtype, str):
if dtype.lower() not in all_types:
raise TypeError(f"`{dtype}` not understood.")
dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
elif isinstance(dtype, type):
dtype = mstype.pytype_to_dtype(dtype)
elif dtype not in mstype.number_type + (mstype.bool_,):
raise TypeError(f"`{dtype}` not understood.")
return dtype
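As a quick sanity check, the three accepted spellings resolve as follows (a minimal sketch, not part of the diff; it assumes a working MindSpore install and the `Validator` import path used elsewhere in this PR):

    from mindspore import dtype as mstype
    from mindspore._checkparam import Validator

    # a lowercase string is routed through np.dtype and pytype_to_dtype
    assert Validator.check_astype_dtype("float32") == mstype.float32
    # an mstype dtype passes through unchanged
    assert Validator.check_astype_dtype(mstype.float32) == mstype.float32
    # Python/NumPy type objects (e.g. np.float32) go through mstype.pytype_to_dtype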
@staticmethod
def check_transpose_axis(axes, ndim):
"""Check the axis argument for tensor.transpose"""
if not axes or (len(axes) == 1 and axes[0] is None):
return tuple(range(ndim-1, -1, -1))
if len(axes) == 1:
perm = axes[0]
# if only one argument provided, it must be tuple or list
if isinstance(perm, list):
perm = tuple(perm)
else:
if not isinstance(perm, tuple):
raise TypeError(f"The `axes` should be a tuple/list, or series of int, but got {type(axes[0])}")
return perm
# if multiple arguments provided, it must be `ndim` number of ints
if len(axes) != ndim:
raise ValueError("The number of axes must equal to the dimension of tensor.")
return axes
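The three call patterns, sketched as direct calls (`axes` arrives as the `*axes` varargs tuple; same `Validator` import assumption as above):

    from mindspore._checkparam import Validator

    assert Validator.check_transpose_axis((), 3) == (2, 1, 0)            # x.transpose()
    assert Validator.check_transpose_axis(([0, 2, 1],), 3) == (0, 2, 1)  # x.transpose([0, 2, 1])
    assert Validator.check_transpose_axis((0, 2, 1), 3) == (0, 2, 1)     # x.transpose(0, 2, 1)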
@staticmethod
def check_reshape_shp(shp):
"""Check the shape argument for tensor.reshape"""
if len(shp) == 1:
new_shape = shp[0]
# if only one argument provided, it must be int, tuple or list
if isinstance(new_shape, int):
return shp
if isinstance(new_shape, list):
new_shape = tuple(new_shape)
else:
if not isinstance(new_shape, tuple):
raise TypeError(
f"The `shape` should be an int, or tuple/list, or series of int, but got {type(shp[0])}")
return new_shape
return shp
@staticmethod
def check_flatten_order(order):
"""Check flatten function input order"""
if not isinstance(order, str):
raise TypeError(f"The order variable should be a string, but got {type(order)}")
if order not in ('C', 'F'):
raise ValueError(f"only `C` and `F` are supported as order, but got {order}")
return order
@staticmethod
def check_swapaxes_axis(axes, ndim):
"""Check all the axes argument for tensor.swapaxes"""
if isinstance(axes, int):
check_axis_in_range(axes, ndim)
return axes % ndim
if isinstance(axes, (tuple, list)):
for axis in axes:
if not isinstance(axis, int):
raise TypeError(f"axis argument should be integer, but got {type(axis)}.")
check_axis_in_range(axis, ndim)
axes = tuple(map(lambda x: x % ndim, axes))
return axes
raise TypeError(f"axes should be integer, list or tuple for check, but got {type(axes)}.")
@staticmethod
def prepare_shape_for_squeeze(shape, axes):
"""
Creates the squeezed new shape based on the tensor and given axes.
Args:
shape (tuple): the shape of the tensor
axes Union[int, tuple(int), list(int)]: the axes with dimensions need to
be squeezed.
Returns:
new_shape(tuple): the shape with dimensions squeezed.
"""
new_shape = []
ndim = len(shape)
# Convert to set
if isinstance(axes, int):
if axes >= ndim or axes < -ndim:
raise ValueError(f"axis {axes} is out of bounds for tensor of dimension {ndim}")
axes = {axes}
elif isinstance(axes, (list, tuple)):
for axis in axes:
if axis >= ndim or axis < -ndim:
raise ValueError(f"axis {axis} is out of bounds for tensor of dimension {ndim}")
axes = set(axes)
else:
raise TypeError(f"only int, tuple and list are allowed for axes, but got {type(axes)}")
for idx, s in enumerate(shape):
if s != 1 or ((idx not in axes) and (idx - ndim not in axes)):
new_shape.append(s)
# if an axis is selected with shape entry greater than one, an error is raised.
if s != 1 and ((idx in axes) or (idx - ndim in axes)):
raise ValueError(f"axis {idx} has shape entry {s} > 1, cannot be squeezed.")
return tuple(new_shape)
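Expected behavior in a nutshell (a sketch under the same import assumption): only size-1 entries named in `axes` are dropped, unnamed size-1 entries survive, and naming a size>1 entry raises:

    from mindspore._checkparam import Validator

    assert Validator.prepare_shape_for_squeeze((1, 2, 1, 3), 0) == (2, 1, 3)
    assert Validator.prepare_shape_for_squeeze((1, 2, 1, 3), (0, 2)) == (2, 3)
    assert Validator.prepare_shape_for_squeeze((1, 2, 1, 3), -2) == (1, 2, 3)
    # prepare_shape_for_squeeze((1, 2, 1, 3), 1) raises ValueError (entry is 2)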
def check_input_format(input_param):
"""Judge input format."""
@@ -623,6 +740,13 @@ def _expand_tuple(n_dimensions):
return convert
def check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
if -ndim <= axis < ndim:
return True
raise ValueError(f'axis {axis} is out of bounds for tensor of dimension {ndim}')
def _check_data_type_valid(data, valid_type):
"""Check data type valid."""
if valid_type is None:

@@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,13 @@
# limitations under the License.
# ============================================================================
"""standard_method"""
from dataclasses import dataclass
from mindspore import Tensor
from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
@@ -28,14 +31,17 @@ from ...ops.primitive import constexpr
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
trans = P.Transpose()
shape_ = P.Shape()
reshape_ = P.Reshape()
dtype_ = P.DType()
abs_ = P.Abs()
ndim_ = P.Rank()
size_ = P.Size()
itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1,
mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2,
mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4,
mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8}
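This map backs the new `itemsize` accessor, and `nbytes` is simply itemsize times the element count. A sketch using the properties this PR adds (assumes a working MindSpore install):

    import numpy as np
    from mindspore import Tensor

    t = Tensor(np.ones((2, 3), np.float32))
    assert t.itemsize == 4                  # float32 -> 4 bytes per element
    assert t.nbytes == t.itemsize * t.size  # 4 * 6 == 24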
def mean(x, axis=(), keep_dims=False):
"""
Reduces a dimension of a tensor by averaging all elements in the dimension.
@@ -93,23 +99,150 @@ def any_(x, axis=(), keep_dims=False):
return reduce_any(x, axis)
def itemsize_(x):
"""
Return length of one tensor element in bytes.
Args:
x (Tensor): Input tensor.
Returns:
itemsize(int).
"""
return get_itemsize(x.dtype)
def nbytes_(x):
"""
Return total number of bytes taken by the tensor.
Args:
x (Tensor): Input tensor.
Returns:
nbytes(int).
"""
return itemsize_(x) * F.shape_mul(shape_(x))
def strides_(x):
"""
Return the tuple of bytes to step in each dimension when traversing a tensor.
Args:
x (Tensor): Input tensor.
Returns:
strides (tuple[int]).
"""
strides = ()
ndim = P.Rank()(x)
tensor_shape = shape_(x)
for i in F.make_range(0, ndim):
stride = itemsize_(x)
for j in F.make_range(i + 1, ndim):
stride *= tensor_shape[j]
strides += (stride,)
return strides
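The loop computes row-major (C-contiguous) strides, i.e. stride[i] = itemsize * prod(shape[i+1:]). A standalone Python mirror of the same arithmetic (a sketch, independent of MindSpore):

    def row_major_strides(shape, itemsize):
        """Bytes to step in each dimension of a C-contiguous layout."""
        strides = []
        for i in range(len(shape)):
            stride = itemsize
            for dim in shape[i + 1:]:
                stride *= dim
            strides.append(stride)
        return tuple(strides)

    assert row_major_strides((2, 3, 4), 4) == (48, 16, 4)  # e.g. float32, shape (2, 3, 4)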
def astype(x, dtype, copy=True):
"""Implementation of `astype`."""
dtype = check_astype_dtype_const(dtype)
if not copy and dtype == x.dtype:
return x
return F.cast(x, dtype)
def transpose(x, *axis):
"""Implementation of `transpose`."""
ndim = F.rank(x)
perm = check_transpose_axis_const(axis, ndim)
return F.transpose(x, perm)
# `tensor.T` is used as a property in graph mode
T_ = transpose
def reshape(x, *shape):
"""Implementation of `reshape`."""
new_shape = check_reshape_shp_const(shape)
return F.reshape(x, new_shape)
def ravel(x):
"""Implementation of `ravel`."""
return reshape(x, (-1,))
def flatten(x, order='C'):
"""
Returns a copy of the array collapsed into one dimension.
Args:
order (str, optional): Can choose between `C` and `F`. `C` means to
flatten in row-major (C-style) order. `F` means to flatten in column-major
(Fortran-style) order. Only `C` and `F` are supported.
Returns:
Tensor, has the same data type as x.
"""
order = check_flatten_order_const(order)
if order == 'C':
return F.reshape(x, (-1,))
perm = F.make_range(0, F.rank(x))
new_order = F.tuple_reversed(perm)
return F.reshape(F.transpose(x, new_order), (-1,))
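The `F` branch obtains Fortran order by reversing the axes and then reshaping row-major; plain NumPy confirms the equivalence (a sketch for intuition only):

    import numpy as np

    a = np.array([[1, 2, 3], [4, 5, 6]])
    assert list(a.flatten('C')) == [1, 2, 3, 4, 5, 6]
    assert list(a.flatten('F')) == [1, 4, 2, 5, 3, 6]
    # reversing the axes, then flattening row-major, reproduces 'F' order
    assert list(a.transpose(1, 0).flatten('C')) == list(a.flatten('F'))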
def swapaxes(x, axis1, axis2):
"""
Interchanges two axes of a tensor.
Args:
axis1 (int): First axis.
axis2 (int): Second axis.
Returns:
Transposed tensor, has the same data type as the original tensor x.
"""
axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim)
if axis1 == axis2:
return x
if axis1 > axis2:
axis1, axis2 = axis2, axis1
perm = F.make_range(0, x.ndim)
new_perm = None
if axis2 + 1 < x.ndim:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
else:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1]
return F.transpose(x, new_perm)
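The slicing above builds the identity permutation with axis1 and axis2 exchanged; an equivalent standalone sketch of the same permutation:

    def swap_perm(ndim, axis1, axis2):
        """Identity permutation of length ndim with two positions exchanged."""
        perm = list(range(ndim))
        perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
        return tuple(perm)

    assert swap_perm(4, 1, 3) == (0, 3, 2, 1)  # swapaxes(x, 1, 3) on a 4-D tensor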
def squeeze(x, axis=None):
"""
Removes single-dimensional entries from the shape of a tensor.
Args:
axis (Union[None, int, list(int), tuple(int)]): Default is None.
Returns:
Tensor, with all or a subset of the dimensions of length 1 removed.
"""
shape = F.shape(x)
if axis is None:
return F.squeeze(x)
# yield squeezed shape based on the axes
new_shape = prepare_shape_for_squeeze_const(shape, axis)
return F.reshape(x, new_shape)
def getitem(data, item):
@@ -200,7 +333,7 @@ def expand_tensor_as(x, y):
def view(x, *shape):
"""Reshape tensor, if shape is -1, reshape tensor into one dimension"""
shape = check_view_shape(shape)
return F.reshape(x, shape)
def isinstance_(x, base_type):
@@ -240,6 +373,12 @@ def check_type_same(x_type, base_type):
raise TypeError(f"The type '{base_type}' is not supported for 'isinstance'")
@constexpr
def get_itemsize(x_type):
"""get itemsize from tensor's dtype."""
return itemsize_map[x_type]
@constexpr
def check_is_tensor(x):
"""check whether x is tensor."""
@@ -298,14 +437,14 @@ def check_view_shape(x):
x = x[0]
return x
@constexpr
def convert_list_to_tuple(shp):
"""Check the type of the shape, if is list, convert to tuple"""
if not isinstance(shp, (list, tuple)):
raise ValueError(f"The shape variable should be a list or tuple, but got {type(shp)}")
if isinstance(shp, list):
shp = tuple(shp)
return shp
# convert normal param_check functions to constexpr functions
check_astype_dtype_const = constexpr(validator.check_astype_dtype)
check_transpose_axis_const = constexpr(validator.check_transpose_axis)
check_reshape_shp_const = constexpr(validator.check_reshape_shp)
check_flatten_order_const = constexpr(validator.check_flatten_order)
check_swapaxes_axis_const = constexpr(validator.check_swapaxes_axis)
prepare_shape_for_squeeze_const = constexpr(validator.prepare_shape_for_squeeze)
def tensor_bool(x):
"""tensor as conditon, if is constant, return immediate bool value"""

@@ -178,6 +178,12 @@ BuiltInTypeMap &GetMethodMap() {
{"__ms_to_array__", prim::kPrimIdentity}, // P.identity,
{"item", prim::kPrimArrayToScalar}, // P.array_to_scalar,
{"transpose", std::string("transpose")}, // P.transpose
{"flatten", std::string("flatten")}, // P.reshape(,-1)
{"reshape", std::string("reshape")}, // P.reshape()
{"ravel", std::string("ravel")}, // P.reshape(,(-1,))
{"swapaxes", std::string("swapaxes")}, // P.transpose()
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"astype", std::string("astype")}, // P.cast()
{"__bool__", std::string("tensor_bool")}, // C.tensor_bool
}},
{kObjectTypeJTagged, {}},
@@ -190,10 +196,14 @@ BuiltInTypeMap &GetAttrMap() {
static BuiltInTypeMap attr_map = {
{kObjectTypeTensorType,
{
{"shape", std::string("shape_")}, // C.shape_
{"dtype", std::string("dtype_")}, // C.dtype_
{"size", std::string("size_")}, // C.size_
{"ndim", std::string("ndim_")}, // C.ndim_
{"shape", std::string("shape_")}, // C.shape_
{"dtype", std::string("dtype_")}, // C.dtype_
{"size", std::string("size_")}, // C.size_
{"ndim", std::string("ndim_")}, // C.ndim_
{"T", std::string("T_")}, // C.T_
{"itemsize", std::string("itemsize_")}, // C.itemsize_
{"nbytes", std::string("nbytes_")}, // C.nbytes_
{"strides", std::string("strides_")}, // C.strides_
}},
{kObjectTypeRowTensorType,
{

@@ -258,6 +258,20 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
return dims;
}
py::tuple TensorPy::GetPyTupleStrides(const Tensor &tensor) {
std::vector<ssize_t> shape(tensor.shape().begin(), tensor.shape().end());
std::vector<ssize_t> strides = GetStrides(shape, tensor.data().itemsize());
py::tuple py_strides(strides.size());
for (size_t i = 0; i < strides.size(); ++i) {
py_strides[i] = py::int_(strides[i]);
}
return py_strides;
}
py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().itemsize(); }
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
{
py::gil_scoped_release gil_release;
@@ -383,6 +397,40 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.size
6
)mydelimiter")
.def_property_readonly("_itemsize", TensorPy::GetPyItemSize, R"mydelimiter(
Get the tensor's length of one element in bytes.
Returns:
itemsize, length of one element in bytes.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.itemsize
4
)mydelimiter")
.def_property_readonly("_nbytes", TensorPy::GetPyNBytes, R"mydelimiter(
Get the tensor's total number of bytes.
Returns:
nbytes, total number of bytes taken by the tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.nbytes
8
)mydelimiter")
.def_property_readonly("_strides", TensorPy::GetPyTupleStrides, R"mydelimiter(
Get the tensor's tuple of bytes to step in each dimension
when traversing an array.
Returns:
tuple[int], the strides of the tensor.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 1), np.int32))
>>> data.strides
(4, 4)
)mydelimiter")
.def("from_numpy", TensorPy::MakeTensorNoCopy, R"mydelimiter(
Creates a Tensor from a numpy.ndarray without copy.

@@ -109,6 +109,12 @@ class TensorPy {
static py::array AsNumpy(const Tensor &tensor);
static py::tuple GetPyTupleShape(const Tensor &tensor);
static py::tuple GetPyTupleStrides(const Tensor &tensor);
static py::int_ GetPyItemSize(const Tensor &tensor);
static py::int_ GetPyNBytes(const Tensor &tensor);
};
} // namespace tensor
} // namespace mindspore

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -267,6 +267,26 @@ class Tensor(Tensor_):
"""tensor is inited."""
return self.init is not None
@property
def itemsize(self):
"""The length of one tensor element in bytes."""
return self._itemsize
@property
def strides(self):
"""The tuple of bytes to step in each dimension when traversing a tensor."""
return self._strides
@property
def nbytes(self):
"""The total number of bytes taken by the tensor."""
return self._nbytes
@property
def T(self):
"""The transposed tensor."""
return self.transpose()
@property
def virtual_flag(self):
"""Mark tensor is virtual."""
@@ -384,6 +404,145 @@ class Tensor(Tensor_):
axis = ()
return tensor_operator_registry.get('mean')(keep_dims)(self, axis)
def transpose(self, *axes):
"""
Returns a view of the array with axes transposed.
For a 1-D array this has no effect, as a transposed vector is simply the
same vector. For a 2-D array, this is a standard matrix transpose. For an
n-D array, if axes are given, their order indicates how the axes are permuted
(see Examples). If axes are not provided and a.shape = (i[0], i[1],...
i[n-2], i[n-1]), then a.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]).
Args:
axes(Union[None, tuple(int), list(int), n ints], optional):
None or no argument: reverses the order of the axes.
Tuple of ints: `i` in the j-th place in the tuple means that the
array's i-th axis becomes the transposed array's j-th axis.
n ints: this form is intended simply as a convenience alternative
to the tuple form.
Returns:
Tensor, has the same dimension as input tensor, with axes suitably permuted.
"""
perm = validator.check_transpose_axis(axes, self.ndim)
return tensor_operator_registry.get('transpose')()(self, perm)
def reshape(self, *shape):
"""
Gives a new shape to an array without changing its data.
Args:
shape(Union[int, tuple(int), list(int)]): The new shape should be compatible
with the original shape. If an integer, then the result will be a 1-D
array of that length. One shape dimension can be -1. In this case, the
value is inferred from the length of the array and remaining dimensions.
Returns:
reshaped_tensor(Tensor): This will be a new view object if possible;
otherwise, it will be a copy.
"""
new_shape = validator.check_reshape_shp(shape)
return tensor_operator_registry.get('reshape')()(self, new_shape)
def ravel(self):
"""
Returns a contiguous flattened tensor.
A 1-D tensor, containing the elements of the input, is returned.
Returns:
Tensor, has the same data type as x.
"""
reshape_op = tensor_operator_registry.get('reshape')()
return reshape_op(self, (-1,))
def flatten(self, order='C'):
"""
Returns a copy of the array collapsed into one dimension.
Args:
order (str, optional): Can choose between `C` and `F`. `C` means to
flatten in row-major (C-style) order. `F` means to flatten in column-major
(Fortran-style) order. Only `C` and `F` are supported.
Returns:
Tensor, has the same data type as the original tensor.
"""
reshape_op = tensor_operator_registry.get('reshape')()
trans_op = tensor_operator_registry.get('transpose')()
order = validator.check_flatten_order(order)
if order == 'C':
return reshape_op(self, (-1,))
perm = tuple(range(self.ndim-1, -1, -1))
return reshape_op(trans_op(self, perm), (-1,))
def swapaxes(self, axis1, axis2):
"""
Interchanges two axes of a tensor.
Args:
axis1 (int): First axis.
axis2 (int): Second axis.
Returns:
Transposed tensor, has the same data type as the original tensor.
"""
axis1, axis2 = validator.check_swapaxes_axis((axis1, axis2), self.ndim)
if axis1 == axis2:
return self
if axis1 > axis2:
axis1, axis2 = axis2, axis1
perm = tuple(range(0, self.ndim))
new_perm = None
if axis2 + 1 < self.ndim:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
else:
new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
perm[axis1+1:axis2] + perm[axis1:axis1+1]
return tensor_operator_registry.get('transpose')()(self, new_perm)
def squeeze(self, axis=None):
"""
Removes single-dimensional entries from the shape of a tensor.
Args:
axis (Union[None, int, list(int), tuple(int)]): Default is None.
Returns:
Tensor, with all or a subset of the dimensions of length 1 removed.
"""
if axis is None:
return tensor_operator_registry.get('squeeze')(self)
new_shape = validator.prepare_shape_for_squeeze(self.shape, axis)
return tensor_operator_registry.get('reshape')()(self, new_shape)
def astype(self, dtype, copy=True):
"""
Returns a copy of the array, cast to a specified type.
Args:
dtype(Union[mstype.dtype, numpy.dtype, str]): Designated tensor dtype,
can be given as mstype.float32, numpy.float32 or the string `float32`.
copy(bool, optional): By default, astype always returns a newly allocated
tensor. If this is set to False, the input tensor is returned instead
of a copy if possible.
Returns:
Tensor, with the designated dtype.
"""
dtype = validator.check_astype_dtype(dtype)
if not copy and dtype == self.dtype:
return self
return tensor_operator_registry.get('cast')(self, dtype)
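Example usage of the new method outside graph mode (a sketch; assumes a working MindSpore install):

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    t = Tensor(np.ones((2, 3), np.float32))
    assert t.astype("int32").dtype == mstype.int32
    # with copy=False and a matching dtype, the same object is returned
    assert t.astype(mstype.float32, copy=False) is t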
def init_data(self, slice_index=None, shape=None, opt_shard_group=None):
"""

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,22 +25,33 @@ Note:
- random/ defines all the random operations.
"""
from .array_ops import (transpose, expand_dims, squeeze, rollaxis, swapaxes, reshape,
ravel, concatenate, where, atleast_1d, atleast_2d, atleast_3d,
column_stack, hstack, dstack, vstack, stack, unique)
from .array_creations import copy_ as copy
from .array_creations import (array, asarray, asfarray, ones, zeros, full, arange,
linspace, logspace, eye, identity, empty, empty_like,
ones_like, zeros_like, full_like, diagonal, tril, triu,
tri, trace)
from .dtypes import (int_, int8, int16, int32, int64, uint, uint8, uint16,
uint32, uint64, float_, float16, float32, float64, bool_, inf,
numeric_types)
from .math_ops import (mean, inner, add, subtract, multiply, divide, power,
dot, outer, tensordot, absolute)
array_ops_module = ['transpose', 'expand_dims', 'squeeze', 'rollaxis', 'swapaxes', 'reshape',
'ravel', 'concatenate', 'where', 'atleast_1d', 'atleast_2d', 'atleast_3d',
'column_stack', 'hstack', 'dstack', 'vstack', 'stack', 'unique']
array_creations_module = ['array', 'asarray', 'asfarray', 'ones', 'zeros', 'full', 'arange',
'linspace', 'logspace', 'eye', 'identity', 'empty', 'empty_like',
'ones_like', 'zeros_like', 'full_like', 'diagonal', 'tril', 'triu',
'tri', 'trace']
math_module = ['mean', 'inner', 'add', 'subtract', 'multiply', 'divide', 'power',
'dot', 'outer', 'tensordot', 'absolute']
__all__ = array_ops_module + array_creations_module + math_module + numeric_types
__all__.sort()

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +14,15 @@
# ============================================================================
"""Dtypes and utilities"""
from ..common.dtype import (int8, int16, int32, int64, uint8, uint16, uint32, uint64,
float16, float32, float64, bool_)
# original numpy has int->int64, float->float64, uint->uint64 mappings. We map
# them to 32-bit types, since 64-bit calculation is not yet supported by the
# MindSpore backend.
inf = float('inf')
nan = float('nan')
int_ = int32
uint = uint32
@@ -94,3 +95,69 @@ all_types = [
'np.float32',
'np.float64',
'np.bool']
promotion_rule = {
(uint8, uint16): uint16,
(uint8, uint32): uint32,
(uint8, uint64): uint64,
(uint16, uint32): uint32,
(uint16, uint64): uint64,
(uint32, uint64): uint64,
(uint8, int8): int16,
(uint8, int16): int16,
(uint8, int32): int32,
(uint8, int64): int64,
(uint16, int8): int32,
(uint16, int16): int32,
(uint16, int32): int32,
(uint16, int64): int64,
(uint32, int8): int64,
(uint32, int16): int64,
(uint32, int32): int64,
(uint32, int64): int64,
(uint64, int8): float64,
(uint64, int16): float64,
(uint64, int32): float64,
(uint64, int64): float64,
(uint8, float16): float16,
(uint8, float32): float32,
(uint8, float64): float64,
(uint16, float16): float16,
(uint16, float32): float32,
(uint16, float64): float64,
(uint32, float16): float16,
(uint32, float32): float32,
(uint32, float64): float64,
(uint64, float16): float16,
(uint64, float32): float32,
(uint64, float64): float64,
(int8, int16): int16,
(int8, int32): int32,
(int8, int64): int64,
(int16, int32): int32,
(int16, int64): int64,
(int32, int64): int64,
(int8, float16): float16,
(int8, float32): float32,
(int8, float64): float64,
(int16, float16): float16,
(int16, float32): float32,
(int16, float64): float64,
(int32, float16): float16,
(int32, float32): float32,
(int32, float64): float64,
(float16, float32): float32,
(float16, float64): float64,
(float32, float64): float64,
(bool_, uint8): uint8,
(bool_, uint16): uint16,
(bool_, uint32): uint32,
(bool_, uint64): uint64,
(bool_, int8): int8,
(bool_, int16): int16,
(bool_, int32): int32,
(bool_, int64): int64,
(bool_, float16): float16,
(bool_, float32): float32,
(bool_, float64): float64,
}
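The table stores each pair under one ordering, so a lookup has to try both orders. A hypothetical helper (the name `promote` is illustrative, not part of this diff; the module path follows this PR's layout):

    import mindspore.common.dtype as mstype
    from mindspore.numpy.dtypes import promotion_rule

    def promote(t1, t2):
        """Return the promoted mstype for a pair of dtypes."""
        if t1 == t2:
            return t1
        return promotion_rule.get((t1, t2)) or promotion_rule.get((t2, t1))

    assert promote(mstype.int8, mstype.uint8) == mstype.int16
    assert promote(mstype.bool_, mstype.float32) == mstype.float32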

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -1,6 +1,6 @@
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ isconstant.set_const_prim(True)
issubclass_ = P.IsSubClass()
isinstance_ = P.IsInstance()
eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
@@ -45,6 +46,7 @@ control_depend = P.ControlDepend()
merge = P.Merge()
geswitch = P.GeSwitch()
addn = P.AddN()
absolute = P.Abs()
tensor_add = P.TensorAdd()
neg_tensor = P.Neg()
tensor_lt = P.Less()
@@ -67,6 +69,8 @@ assign_add = P.AssignAdd()
assign = P.Assign()
square = P.Square()
sqrt = P.Sqrt()
reduce_sum = P.ReduceSum()
tensor_slice = P.Slice()
scalar_to_array = P.ScalarToArray()
scalar_to_tensor = P.ScalarToTensor()
@@ -74,6 +78,8 @@ tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast()
print_ = P.Print()
expand_dims = P.ExpandDims()
transpose = P.Transpose()
squeeze = P.Squeeze()
scatter_nd = P.ScatterNd()
gather = P.GatherV2()
gather_nd = P.GatherNd()
@@ -177,6 +183,7 @@ tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
@@ -187,6 +194,7 @@ tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('shape', shape)
tensor_operator_registry.register('squeeze', squeeze)
# support GE backend for no compare operators
tensor_operator_registry.register('cast', cast)

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -69,6 +69,24 @@ def test_tensor_size():
assert arr.size == b.size
def test_tensor_itemsize():
arr = np.ones((1, 2, 3))
b = ms.Tensor(arr)
assert arr.itemsize == b.itemsize
def test_tensor_strides():
arr = np.ones((3, 4, 5, 6))
b = ms.Tensor(arr)
assert arr.strides == b.strides
def test_tensor_nbytes():
arr = np.ones((3, 4, 5, 6))
b = ms.Tensor(arr)
assert arr.nbytes == b.nbytes
def test_dtype():
a = ms.Tensor(np.ones((2, 3), dtype=np.int32))
assert a.dtype == ms.int32

File diff suppressed because it is too large

@@ -1,102 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""unit tests for numpy math operations"""
import pytest
import numpy as onp
import mindspore.numpy as mnp
def rand_int(*shape):
"""return an random integer array with parameter shape"""
res = onp.random.randint(low=1, high=5, size=shape)
if isinstance(res, onp.ndarray):
res = res.astype(onp.float32)
return res
class Cases():
def __init__(self):
self.arrs = [
rand_int(2),
rand_int(2, 3),
rand_int(2, 3, 4),
rand_int(2, 3, 4, 5),
]
# scalars expanded across the 0th dimension
self.scalars = [
rand_int(),
rand_int(1),
rand_int(1, 1),
rand_int(1, 1, 1),
]
# arrays with last dimension aligned
self.aligned_arrs = [
rand_int(2, 3),
rand_int(1, 4, 3),
rand_int(5, 1, 2, 3),
rand_int(4, 2, 1, 1, 3),
]
test_case = Cases()
def mnp_inner(a, b):
return mnp.inner(a, b)
def onp_inner(a, b):
return onp.inner(a, b)
def test_inner():
for arr1 in test_case.aligned_arrs:
for arr2 in test_case.aligned_arrs:
match_res(mnp_inner, onp_inner, arr1, arr2)
for scalar1 in test_case.scalars:
for scalar2 in test_case.scalars:
match_res(mnp_inner, onp_inner,
scalar1, scalar2)
# check whether the outputs from the mnp and onp functions applied to the arrays match
def match_res(mnp_fn, onp_fn, arr1, arr2):
actual = mnp_fn(mnp.asarray(arr1, dtype='float32'),
mnp.asarray(arr2, dtype='float32')).asnumpy()
expected = onp_fn(arr1, arr2)
match_array(actual, expected)
def match_array(actual, expected, error=5):
if error > 0:
onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
decimal=error)
else:
onp.testing.assert_equal(actual.tolist(), expected.tolist())
def test_exception_innner():
with pytest.raises(ValueError):
mnp.inner(mnp.asarray(test_case.arrs[0]),
mnp.asarray(test_case.arrs[1]))

@@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test astype"""
import pytest
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_astype():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.astype("float16")
net = Net()
res = net()
assert res.dtype == mstype.float16
def test_astype_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.int64)
def construct(self):
return self.value.astype(mstype.bool_)
net = Net()
res = net()
assert res.dtype == mstype.bool_
def test_astype_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float64)
def construct(self):
return self.value.astype(mstype.uint64)
net = Net()
res = net()
assert res.dtype == mstype.uint64
def test_astype_error_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.astype("float88")
net = Net()
with pytest.raises(TypeError):
net()

@@ -0,0 +1,77 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test flatten"""
import pytest
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_flatten():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.flatten()
net = Net()
net()
def test_flatten_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.flatten(order='F')
net = Net()
net()
def test_flatten_error():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.flatten(order='X')
net = Net()
with pytest.raises(ValueError):
net()
def test_flatten_error_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.flatten(order=123)
net = Net()
with pytest.raises(TypeError):
net()

@@ -0,0 +1,103 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test tensor properties in graph mode"""
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_ndim():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.float32)
def construct(self):
return self.value.ndim
net = Net()
res = net()
assert res == 4
def test_nbytes():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.float32)
def construct(self):
return self.value.nbytes
net = Net()
res = net()
assert res == 480
def test_size():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.float32)
def construct(self):
return self.value.size
net = Net()
res = net()
assert res == 120
def test_strides():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.float32)
def construct(self):
return self.value.strides
net = Net()
res = net()
assert res == (240, 80, 20, 4)
def test_itemsize():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value1 = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.float64)
self.value2 = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.int32)
self.value3 = Tensor(np.random.random(
(2, 3, 4, 5)), dtype=mstype.bool_)
def construct(self):
return (self.value1.itemsize, self.value2.itemsize, self.value3.itemsize)
net = Net()
res = net()
assert res == (8, 4, 1)

@@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test reshape"""
import pytest
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
def test_reshape():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.reshape(-1)
net = Net()
net()
def test_reshape_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.reshape([3, 2, 1])
net = Net()
net()
def test_reshape_2():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.reshape((-1, 2))
net = Net()
net()
def test_reshape_error():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.reshape(1, 2, 4)
net = Net()
with pytest.raises(ValueError):
net()
def test_reshape_error_1():
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.value = Tensor([[1, 2, 3], [4, 5, 6]], dtype=mstype.float32)
def construct(self):
return self.value.reshape((1, 2, 3.5))
net = Net()
with pytest.raises(TypeError):
net()

Some files were not shown because too many files have changed in this diff
