|
|
|
@ -22,15 +22,14 @@ from ..ops.primitive import constexpr
|
|
|
|
|
from ..nn.layer.basic import tril as nn_tril
|
|
|
|
|
from ..nn.layer.basic import triu as nn_triu
|
|
|
|
|
from .._c_expression import Tensor as Tensor_
|
|
|
|
|
from .._c_expression.typing import Float
|
|
|
|
|
|
|
|
|
|
from .utils import _check_input_for_asarray, _deep_list, _deep_tensor_to_nparray, \
|
|
|
|
|
_broadcast_to_shape, _check_input_tensor, _convert_64_to_32, _get_dtype_from_scalar, \
|
|
|
|
|
_expand
|
|
|
|
|
from .utils_const import _raise_value_error, _empty, _check_axis_valid, _max, _min, \
|
|
|
|
|
_check_same_type, _is_shape_empty, _check_shape, _check_dtype, _tile_size, _abs, \
|
|
|
|
|
_raise_type_error, _expanded_shape, _tuple_getitem, _check_is_float, _iota, \
|
|
|
|
|
_type_convert, _canonicalize_axis, _list_comprehensions, _ceil
|
|
|
|
|
_raise_type_error, _expanded_shape, _check_is_float, _iota, _type_convert, \
|
|
|
|
|
_canonicalize_axis, _list_comprehensions, _ceil, _tuple_getitem, _tuple_slice
|
|
|
|
|
from .array_ops import transpose, ravel, concatenate, broadcast_arrays, reshape, broadcast_to
|
|
|
|
|
from .dtypes import nan
|
|
|
|
|
|
|
|
|
@ -503,9 +502,9 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
|
|
|
|
|
start, stop = broadcast_arrays(start, stop)
|
|
|
|
|
axis = _canonicalize_axis(axis, start.ndim+1)
|
|
|
|
|
bounds_shape = start.shape
|
|
|
|
|
bounds_shape = bounds_shape[:axis] + (1,) + bounds_shape[axis:]
|
|
|
|
|
bounds_shape = _tuple_slice(bounds_shape, None, axis) + (1,) + _tuple_slice(bounds_shape, axis, None)
|
|
|
|
|
iota_shape = _list_comprehensions(start.ndim+1, 1, True)
|
|
|
|
|
iota_shape = iota_shape[:axis] + (num,) + iota_shape[axis+1:]
|
|
|
|
|
iota_shape = _tuple_slice(iota_shape, None, axis) + (num,) + _tuple_slice(iota_shape, axis+1, None)
|
|
|
|
|
num_tensor = _type_convert(Tensor, num).astype(mstype.float32)
|
|
|
|
|
div = (num_tensor - 1) if endpoint else num_tensor
|
|
|
|
|
|
|
|
|
@ -1542,7 +1541,7 @@ def diag(v, k=0):
|
|
|
|
|
prod = F.tensor_mul(v, e)
|
|
|
|
|
|
|
|
|
|
cast_type = dtype
|
|
|
|
|
if not isinstance(dtype, Float):
|
|
|
|
|
if not _check_is_float(dtype):
|
|
|
|
|
# reduce sum only supports float types
|
|
|
|
|
cast_type = mstype.float32
|
|
|
|
|
prod = F.cast(prod, cast_type)
|
|
|
|
|