You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
4.5 KiB
142 lines
4.5 KiB
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""internal utility functions"""
|
|
|
|
import numpy as onp
|
|
|
|
from ..common import Tensor
|
|
from ..ops import functional as F
|
|
from ..common import dtype as mstype
|
|
|
|
from .utils_const import _tile_size, _add_unit_axes, _raise_type_error
|
|
|
|
|
|
def _deep_list(array_like):
|
|
"""convert nested tuple/list mixtures to pure nested list"""
|
|
if isinstance(array_like, (list, tuple)):
|
|
return list(map(_deep_list, array_like))
|
|
return array_like
|
|
|
|
|
|
def _deep_tensor_to_nparray(array_like):
    """
    Replace every Tensor found in a nested list with its numpy counterpart.

    Args:
        array_like(list(tensor)): A Tensor, a (possibly nested) list that may
            contain tensors, or any other object.

    Returns:
        array_like(list(np_array)): For a Tensor, the result of `asnumpy()`.
        For a list, the same list object with its elements converted in
        place. Any other object is returned unchanged.
    """
    # A bare tensor is converted directly.
    if isinstance(array_like, Tensor):
        return array_like.asnumpy()

    # Only lists are descended into; everything else is already a leaf.
    if not isinstance(array_like, list):
        return array_like

    # Slice assignment rewrites the contents of the caller's list object,
    # keeping the in-place behaviour of the conversion.
    array_like[:] = [_deep_tensor_to_nparray(element) for element in array_like]
    return array_like
|
|
|
|
|
|
def _check_input_for_asarray(array_like):
    """
    Check whether array_like is a valid type for np.asarray conversion.

    Args:
        array_like: Object to validate. Accepted types are `Tensor`, `list`,
            `tuple`, `int`, `float`, `bool` and `numpy.ndarray`.

    Raises:
        Whatever `_raise_type_error` raises when array_like is not one of the
        accepted types (presumably a TypeError - defined in utils_const).
    """
    if not isinstance(array_like, (Tensor, list, tuple, int, float, bool, onp.ndarray)):
        # The message is a single literal so that the words "`tuple`" and "or"
        # are separated by a space (the two concatenated pieces used to run
        # together as "`tuple`or numpy.ndarray").
        _raise_type_error("input data must be `int`, `float`, `bool`, `Tensor`, `list`, `tuple` "
                          "or numpy.ndarray, but got ", array_like)
|
|
|
|
|
|
def _is_scalar(shape):
    """
    Check whether the input shape describes a scalar.

    Args:
        shape: Shape passed to `F.shape_mul`.

    Returns:
        True when `F.shape_mul(shape)` equals 1, False otherwise.
    """
    # NOTE(review): F.shape_mul presumably multiplies all dimensions, so a
    # shape such as (1, 1) would also count as scalar - confirm.
    element_count = F.shape_mul(shape)
    return element_count == 1
|
|
|
|
|
|
def _convert_list_tensor_to_tuple_tensor(list_of_tensor):
|
|
"""Convert a list of tensor to a tuple of tensor"""
|
|
if isinstance(list_of_tensor, list):
|
|
tuple_of_tensor = ()
|
|
for tensor in list_of_tensor:
|
|
tuple_of_tensor += (tensor,)
|
|
return tuple_of_tensor
|
|
return list_of_tensor
|
|
|
|
|
|
def _expand(x, ndim, axis=0):
    """
    Expand x to ndim from axis, which can be 0 or -1.

    Args:
        x: Input whose shape is read with `F.shape` and which is reshaped.
        ndim: Target number of dimensions, forwarded to `_add_unit_axes`.
        axis: 0 or -1. Only the comparison `axis == -1` is forwarded to
            `_add_unit_axes`. Default: 0.

    Returns:
        x reshaped with `F.reshape` to the shape computed by `_add_unit_axes`.
    """
    # NOTE(review): the flag presumably selects trailing instead of leading
    # unit axes - confirm against _add_unit_axes in utils_const.
    from_last_axis = axis == -1
    current_shape = F.shape(x)
    expanded_shape = _add_unit_axes(current_shape, ndim, from_last_axis)
    return F.reshape(x, expanded_shape)
|
|
|
|
|
|
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """
    Broadcast x from shape_cur to shape_to by tiling.

    Args:
        x: Input to be tiled.
        shape_cur: Current shape of x, forwarded to `_tile_size`.
        shape_to: Target shape, forwarded to `_tile_size`.
        ndim_to: Number of dimensions of the target, forwarded to `_tile_size`.

    Returns:
        The result of `F.tile` applied to x with the multiples computed by
        `_tile_size(shape_cur, shape_to, ndim_to)`.
    """
    multiples = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, multiples)
|
|
|
|
|
|
def _broadcast_to_shape(x, shape):
    """
    Broadcast x from its current shape to shape.

    Args:
        x: Input to broadcast.
        shape: Target shape. Its length is used as the target number of
            dimensions.

    Returns:
        x after being expanded to `len(shape)` dimensions with `_expand`
        (default axis) and then broadcast to shape with `_broadcast_to`.
    """
    target_ndim = len(shape)
    # First match the rank of the target, then tile up to the target shape.
    expanded = _expand(x, target_ndim)
    return _broadcast_to(expanded, F.shape(expanded), shape, target_ndim)
|
|
|
|
|
|
def _get_size(x, axis=None):
    """
    Get the number of elements along the given axes of tensor x.

    Args:
        x: Object exposing `ndim` and `shape`.
        axis: Iterable of axis indices. When it is None, or when
            `F.tuple_len(axis)` is 0, all axes of x are used. Default: None.

    Returns:
        The product of `x.shape[ax]` over every selected axis (1 when no axis
        is iterated).
    """
    # Fall back to every axis of x when no axis is specified.
    if axis is None or F.tuple_len(axis) == 0:
        axis = F.make_range(x.ndim)

    element_count = 1
    for dim in axis:
        element_count = element_count * x.shape[dim]
    return element_count
|
|
|
|
|
|
def _check_input_tensor(*tensors):
|
|
for tensor in tensors:
|
|
if not isinstance(tensor, Tensor):
|
|
_raise_type_error('expect Tensor, but got ', F.typeof(tensor))
|
|
return True
|
|
|
|
|
|
def _convert_64_to_32(tensor):
    """
    Convert a tensor with float64/int64 dtype to float32/int32.

    Args:
        tensor: Object exposing `dtype` and `astype`.

    Returns:
        `tensor.astype("float32")` when the dtype is `mstype.float64`,
        `tensor.astype("int32")` when the dtype is `mstype.int64`, otherwise
        tensor itself, unchanged.
    """
    source_dtype = tensor.dtype
    if source_dtype == mstype.float64:
        return tensor.astype("float32")
    if source_dtype == mstype.int64:
        return tensor.astype("int32")
    # Every other dtype is left untouched.
    return tensor
|
|
|
|
|
|
def _get_dtype_from_scalar(*input_numbers):
    """
    Get the final dtype from a series of input numbers. Compared with
    F.typeof, int32/float32 are returned for python int/float instead.

    Args:
        input_numbers: Python scalars. None entries are ignored.

    Returns:
        `mstype.bool_` when every non-None input is a bool (also when there
        is no non-None input), `mstype.int32` when every non-None input is an
        int (bool included), otherwise `mstype.float32`.
    """
    all_bool = True
    all_int = True
    for value in input_numbers:
        if value is not None:
            # bool is a subclass of int, so a bool never clears the int flag.
            all_bool = all_bool and isinstance(value, bool)
            all_int = all_int and isinstance(value, int)

    if all_bool:
        return mstype.bool_
    if all_int:
        return mstype.int32
    return mstype.float32
|