# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""math Operations."""

from itertools import zip_longest
from collections import deque

import numpy as np

from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore.ops.primitive import constexpr
from mindspore.ops import functional as F
from .. import operations as P


# count_nonzero

@constexpr
def _check_validate_axis(axis, name):
    """Validates that `axis` is an int or a tuple/list of ints."""
    if isinstance(axis, (tuple, list)):
        for idx, item in enumerate(axis):
            validator.check_value_type("axis[%d]" % idx, item, [int], name)
    axis = validator.check_value_type('axis', axis, [int, tuple, list], name)
    return axis


@constexpr
def _check_validate_keepdims(keep_dims, name):
    """Validates that `keep_dims` is a bool."""
    keep_dims = validator.check_value_type('keep_dims', keep_dims, [bool], name)
    return keep_dims

def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
    r"""
    Counts the number of nonzero elements across the given axis of the input tensor.

    Args:
        x (Tensor): Input data used to count the non-zero elements.
        axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
            Default: (), reduce all dimensions.
        keep_dims (bool): If true, keep the reduced dimensions with length 1.
            If false, don't keep these dimensions. Default: False.
        dtype (Union[Number, mstype.bool\_]): The data type of the output tensor. Only constant value is allowed.
            Default: mstype.int32.

    Returns:
        Tensor, the number of nonzero elements. The data type is `dtype`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> input_x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
        >>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
        >>> print(nonzero_num)
        [[3]]
    """
    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
    axis = _check_validate_axis(axis, "count_nonzero")
    keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
    const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

    not_equal = P.NotEqual()
    cast = P.Cast()
    reduce_sum = P.ReduceSum(keep_dims)
    nonzero_bool = not_equal(x, 0)
    # ReduceSum only supports float16 or float32 tensors, so count in float16.
    nonzero_val = cast(nonzero_bool, mstype.float16)
    nonzero_num = cast(reduce_sum(nonzero_val, axis), dtype)

    return nonzero_num


# tensor dot

@constexpr
def _int_to_tuple_conv(axes):
    """
    Converts ints to tuples in input axes, expected by most validation checks.
    """
    for x in [0, 1]:
        if isinstance(axes[x], int):
            axes[x] = (axes[x],)
    return axes

@constexpr
def _check_axes(axes):
    """
    Checks the validity and type of the axes passed to the function.
    """
    validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
    if not isinstance(axes, int):
        axes = list(axes)  # to avoid immutability issues
        if len(axes) != 2:
            raise ValueError("Require two axes inputs, but fewer were given.")
        axes = _int_to_tuple_conv(axes)  # convert before length checks
        if len(axes[0]) != len(axes[1]):
            raise ValueError("Axes have to be the same size/length.")
        if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
            raise ValueError("Axes cannot contain duplicate values.")
    return axes

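# Illustration (not part of the checks themselves): _check_axes(1) returns the
# int unchanged, while _check_axes((1, 0)) is normalized to [(1,), (0,)] so
# the later per-axis checks can iterate over both tuples uniformly.
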
@constexpr
def _typecheck_input(x1_type, x2_type):
    """
    Checks that the input tensor types are valid and that both inputs share the same type.
    """
    const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
    const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
    if x1_type != x2_type:
        raise TypeError(f"Both inputs must be the same type. x1 is '{x1_type}' and x2 is '{x2_type}'.")

@constexpr
def _axes_int_check(x1_shape, x2_shape, axes):
    """
    Converts a single int axes to the 2d tuple format if required.
    """
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError(f"axes must be at least 0 for tensor dot, got {axes}")
        if axes == 0:
            # outer product, no input validation required
            return ([], [])
        if axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError("Axes value too high for given input arrays dimensions.")
        x1_ind = tuple(range(len(x1_shape))[-1 * axes:])
        x2_ind = tuple(range(len(x2_shape))[:axes])
        axes = tuple((x1_ind, x2_ind))
        axes = _int_to_tuple_conv(axes)
    return axes

@constexpr
def _validate_axes(x1_shape, x2_shape, axes):
    """
    Checks that the axes have the correct length for each input, that every axis
    value is within range of the given shape, and that the axes values are
    compatible between the two inputs.
    """
    shapes = [x1_shape, x2_shape]

    # axis length check
    for ix_input, x_axes in enumerate(axes):
        axes_len = len(x_axes)
        shape_dim_len = len(shapes[ix_input])
        if axes_len > shape_dim_len:
            raise ValueError(f"axes for input: {ix_input + 1} are of length: {axes_len} "
                             f"can only be max: {shape_dim_len} due to input shape.")

    # axis values range check
    for ix_input, x_axes in enumerate(axes):
        comp_shape = shapes[ix_input]
        max_val = len(comp_shape) - 1
        min_val = -1 * len(comp_shape)
        for _, x_value in enumerate(x_axes):
            if not min_val <= x_value <= max_val:
                raise ValueError(f"axes for input: {ix_input + 1} contains index: "
                                 f"{x_value}, but range is: [{min_val}, {max_val}]")

    # check axis value with input shape - both ways for axis valid
    invalid_a = False
    invalid_b = False
    for i in range(len(axes[0])):  # sizes already validated
        if x1_shape[axes[0][i]] != x2_shape[axes[1][i]]:
            invalid_a = True
        if x1_shape[axes[0][i]] != x2_shape[axes[1][len(axes[0]) - 1 - i]]:
            invalid_b = True
    if invalid_a and invalid_b:
        raise ValueError("Given axes are incompatible with the given input arrays.")

@constexpr
def _calc_new_shape(shape, axes, position=0):
    """
    Calculates the transpose and reshape parameters for the input transformations;
    'position' refers to whether the tensor is first or second in the op.
    """
    contraction_axes = tuple(i if i >= 0 else i + len(shape) for i in axes[position])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    return new_shape, transpose_perm, free_dims

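# Illustration of _calc_new_shape: for shape (1, 2, 3) with contraction axes
# (0, 1) and position=0, the contracted size is 1*2 = 2 and the free axis is 2
# (size 3), so the tensor is transposed by (2, 0, 1) and reshaped to
# (prod_free, prod_contraction) = (3, 2), ready for a plain MatMul.
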
def tensor_dot(x1, x2, axes):
    """
    Computation of tensor contraction on arbitrary axes between tensors `x1` and `x2`.

    Contraction allows for the summation of products of elements of `x1` and `x2` on specified axes.
    The same number of axes must be specified for both `x1` and `x2`, and the values must be within
    range of the number of dims of both `x1` and `x2`.

    Selected dims in both inputs must also match.

    axes = 0 leads to the outer product.
    axes = 1 leads to normal matrix multiplication when both inputs are 2D.
    axes = 1 is the same as axes = ((1,),(0,)) where both `x1` and `x2` are 2D.
    axes = 2 is the same as axes = ((1,2),(0,1)) where both `x1` and `x2` are 3D.

    Inputs:
        - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32
        - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32
        - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
          tuple/list of length 2 with dimensions specified for `x1` and `x2` each. If a single value
          `N` is passed, the last `N` dims of the `x1` input shape and the first `N` dims of the `x2`
          input shape are automatically picked up, in order, as the axes for each input respectively.

    Outputs:
        Tensor, the shape of the output tensor is :math:`(N + M)`, where :math:`N` and :math:`M` are
        the free axes not contracted in both inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
        >>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
        >>> print(output)
        [[2. 2. 2.]
         [2. 2. 2.]
         [2. 2. 2.]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    # input validity checks
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)
    axes = _check_axes(axes)
    _typecheck_input(x1_type, x2_type)
    # input compatibility check & axes format update
    axes = _axes_int_check(x1_shape, x2_shape, axes)
    _validate_axes(x1_shape, x2_shape, axes)
    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
    output_shape = x1_ret + x2_ret  # combine free axes from both inputs
    # run tensor_dot op
    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)
    x2_reshaped = reshape_op(x2_transposed, x2_reshape_fwd)
    mul_result = matmul_op(x1_reshaped, x2_reshaped)
    final_result = reshape_op(mul_result, output_shape)
    return final_result

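# Shape trace for the docstring example above: x1 (1, 2, 3) with axes (0, 1)
# contracts 1*2 = 2 elements and keeps free dim (3,); x2 (3, 1, 2) with axes
# (1, 2) also contracts 2 elements and keeps free dim (3,). The MatMul thus
# runs on (3, 2) x (2, 3) and the result is reshaped to the combined free
# dims (3, 3).
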
@constexpr
def _check_invalid_input(x1_shape, x2_shape):
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError('C.dot inputs x1, x2 should have dimension >= 2, '
                         + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')

@constexpr
def _get_transpose_shape(x2_shape):
    """Builds the permutation that moves x2's second-to-last axis to the front."""
    x2_shape_range = tuple(range(len(x2_shape)))
    x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
    return x2_shape_transpose

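# Illustration: for a rank-3 x2, the axis range (0, 1, 2) becomes the
# permutation (1, 0, 2), i.e. the contracted second-to-last axis is moved to
# the front so dot() below can flatten all remaining axes into one MatMul
# dimension.
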
def dot(x1, x2):
    """
    Computes the dot product between samples in two tensors.

    Inputs:
        - **x1** (Tensor) - First tensor in Dot op with datatype float16 or float32
        - **x2** (Tensor) - Second tensor in Dot op with datatype float16 or float32

    Outputs:
        Tensor, dot product of x1 and x2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[1, 3, 2]), mindspore.float32)
        >>> output = C.dot(input_x1, input_x2)
        >>> print(output)
        [[[3. 3.]]
         [[3. 3.]]]
    """
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    transpose_op = P.Transpose()
    matmul_op = P.MatMul(False, False)
    x1_shape = shape_op(x1)
    x2_shape = shape_op(x2)
    _check_invalid_input(x1_shape, x2_shape)

    if len(x1_shape) > 2 or len(x2_shape) > 2:
        x2_shape_transpose = _get_transpose_shape(x2_shape)
        x2_transpose = transpose_op(x2, x2_shape_transpose)
        x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
        x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
        mul_result = matmul_op(x1_reshape, x2_reshape)
        return reshape_op(mul_result, x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:])
    return matmul_op(x1, x2)

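# Shape trace for the docstring example above: x1 (2, 3) stays (2, 3) after
# the flatten, x2 (1, 3, 2) is transposed to (3, 1, 2) and reshaped to (3, 2);
# the (2, 3) x (3, 2) MatMul yields (2, 2), which is restored to
# x1_shape[:-1] + x2_shape[:-2] + x2_shape[-1:] = (2, 1, 2).
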
@constexpr
def _get_batch_size(x1_shape, x2_shape):
    """
    Gets the batch sizes from the two inputs.
    """
    if len(x1_shape) < 2 or len(x2_shape) < 2:
        raise ValueError("Require both inputs with rank >= 2.")
    return x1_shape[0], x2_shape[0]

@constexpr
def _check_axes_for_batch_dot(x1_shape, x2_shape, axes):
    """
    Checks whether the axes are valid and casts axes from tuple to list.
    """
    if axes is None:
        if len(x2_shape) == 2:
            axes = [len(x1_shape) - 1, len(x2_shape) - 1]
        else:
            axes = [len(x1_shape) - 1, len(x2_shape) - 2]

    if isinstance(axes, (list, tuple)):
        if 0 in axes:
            raise ValueError("Batch dim cannot be used in axes.")
        if len(axes) != 2:
            raise ValueError("Require two axes inputs, but fewer were given.")
        if isinstance(axes, tuple):
            axes = list(axes)
        for sub_axes in axes:
            if isinstance(sub_axes, (list, tuple)):
                raise ValueError("Require each axis to be one of: None, int, (int, int).")
        # normalize negative axes to positive indices
        if axes[0] < 0:
            axes[0] += len(x1_shape)
        if axes[1] < 0:
            axes[1] += len(x2_shape)
    elif isinstance(axes, int):
        if axes == 0:
            raise ValueError("Batch dim cannot be used in axes.")
        if axes < 0:
            axes = [axes + len(x1_shape), axes + len(x2_shape)]
        elif axes > len(x1_shape) or axes > len(x2_shape):
            raise ValueError("Axes value too high for given input arrays dimensions.")
        else:
            axes = [axes, axes]
    else:
        raise ValueError("Axes type must be one of: int, tuple(int), list(int).")
    return axes

@constexpr
def _calc_new_shape_batchdot(shape, axes, position=0):
    """
    Calculates the transpose and reshape parameters for the input transformations;
    'position' refers to whether the tensor is first or second in the op.
    """
    axis = axes[position]
    contraction_axes = tuple([axis])
    prod_contraction = int(np.prod([shape[i] for i in contraction_axes]))
    free_axes = tuple(i for i in range(1, len(shape)) if i not in contraction_axes)
    free_dims = tuple(shape[i] for i in free_axes)
    prod_free = int(np.prod(free_dims))

    transpose_perm = contraction_axes + free_axes if position else free_axes + contraction_axes
    transpose_perm = tuple([0]) + transpose_perm
    new_shape = (prod_contraction, prod_free) if position else (prod_free, prod_contraction)
    new_shape = tuple([shape[0]]) + new_shape
    return new_shape, transpose_perm, free_dims

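# Illustration of _calc_new_shape_batchdot: batch axis 0 is always carried
# through unchanged. For shape (2, 2, 3) with axis 2 and position=0, the free
# axis is 1, the transpose perm is (0, 1, 2) and the new shape is
# (batch, prod_free, prod_contraction) = (2, 2, 3).
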
@constexpr
def _check_batch_size(x1_batch_size, x2_batch_size):
    """
    Checks whether the batch sizes of the two inputs are the same.
    """
    if x1_batch_size != x2_batch_size:
        raise ValueError("Require both inputs with the same batch sizes.")

@constexpr
def _get_output_shape(batch_size, x1_ret, x2_ret):
    """
    Computes the output shape for batch dot.
    """
    output_shape = tuple([batch_size]) + x1_ret + x2_ret
    return output_shape

def batch_dot(x1, x2, axes=None):
    """
    Computes the batch dot product between samples in two tensors containing batch dims.

    Inputs:
        - **x1** (Tensor) - First tensor in Batch Dot op with datatype float16 or float32
        - **x2** (Tensor) - Second tensor in Batch Dot op with datatype float16 or float32.
          x2's datatype should be the same as x1's.
        - **axes** (Union[int, tuple(int), list(int)]) - Single value or tuple/list of length 2 with
          dimensions specified for `x1` and `x2` each. If a single value `N` is passed, the last `N`
          dims of the `x1` input shape and the last `N` dims of the `x2` input shape are automatically
          picked up, in order, as the axes for each input respectively.

    Outputs:
        Tensor, batch dot product of x1 and x2.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input_x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
        >>> input_x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
        >>> axes = (-1, -2)
        >>> output = C.batch_dot(input_x1, input_x2, axes)
        >>> print(output)
        [[[3. 3.]
          [3. 3.]]
         [[3. 3.]
          [3. 3.]]]
    """
    transpose_op = P.Transpose()
    batch_matmul_op = P.BatchMatMul()
    squeeze_one_op = P.Squeeze(1)
    squeeze_minus_one_op = P.Squeeze(-1)
    # input validity checks
    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)
    x1_dim_num = len(x1_shape)
    x2_dim_num = len(x2_shape)
    x1_type = F.dtype(x1)
    x2_type = F.dtype(x2)

    x1_batch_size, x2_batch_size = _get_batch_size(x1_shape, x2_shape)

    _typecheck_input(x1_type, x2_type)
    _check_batch_size(x1_batch_size, x2_batch_size)
    axes = _check_axes_for_batch_dot(x1_shape, x2_shape, axes)

    if x1_dim_num == 2:
        x1 = F.expand_dims(x1, 1)
        axes[0] += 1
    if x2_dim_num == 2:
        x2 = F.expand_dims(x2, 2)

    x1_shape = F.shape(x1)
    x2_shape = F.shape(x2)

    x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape_batchdot(x1_shape, axes, 0)
    x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape_batchdot(x2_shape, axes, 1)
    output_shape = _get_output_shape(x1_batch_size, x1_ret, x2_ret)

    x1_transposed = transpose_op(x1, x1_transpose_fwd)
    x2_transposed = transpose_op(x2, x2_transpose_fwd)
    x1_reshaped = F.reshape(x1_transposed, x1_reshape_fwd)
    x2_reshaped = F.reshape(x2_transposed, x2_reshape_fwd)

    # Batch matmul op part
    mul_result = batch_matmul_op(x1_reshaped, x2_reshaped)

    final_result = F.reshape(mul_result, output_shape)

    # if the original dims were expanded, restore them from 3 back to 2
    if x1_dim_num == 2:
        final_result = squeeze_one_op(final_result)
    elif x2_dim_num == 2:
        final_result = squeeze_minus_one_op(final_result)

    return final_result

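# Shape trace for the docstring example above: axes (-1, -2) normalize to
# [2, 1], so x1 (2, 2, 3) and x2 (2, 3, 2) are already in (batch, free,
# contraction) and (batch, contraction, free) layout; BatchMatMul runs on
# (2, 2, 3) x (2, 3, 2) and gives the (2, 2, 2) output, each element summing
# 3 products of ones.
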
@constexpr
def _check_same_type(dtype1, dtype2):
    """Returns True if both dtypes are equal."""
    return dtype1 == dtype2

@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)

@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)

@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    """Infers the shape of the last two dimensions after performing matmul."""
    shape_rem = []
    if ndim1 >= 2:
        shape_rem.append(shape1[-2])
    if transpose_b:
        if ndim2 >= 2:
            shape_rem.append(shape2[-2])
    else:
        if ndim1 >= 1:
            shape_rem.append(shape2[-1])
    return tuple(shape_rem)

@constexpr
def _check_matmul_shapes(shape1, shape2):
    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 < 1 or ndim2 < 1:
        raise ValueError('input operands must have at least 1 dimension')
    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
        raise ValueError(f'mismatch in core dimension of input operands (size '
                         f'{shape1[-1]} is different from {shape2[-2]})')
    shape_out = deque()
    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        max_size = max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
        shape_out.appendleft(max_size)
    return tuple(shape_out)

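# Illustration of the broadcast walk: for shape1 (2, 1, 3, 4) and shape2
# (5, 4, 6), the batch parts (2, 1) and (5,) are zipped right-aligned with
# fillvalue 1, giving the pairs (1, 5) and (2, 1); each pair broadcasts to its
# max, so the batch output shape is (2, 5).
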
@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape."""
    size = [1] * ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)

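# Illustration of _tile_size: broadcasting a (2, 1, 3) tensor to (2, 5, 3)
# with ndim=3 yields the tile multiples (1, 5, 1), since only the axis whose
# size differs (1 -> 5) needs tiling.
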
@constexpr
def _check_need_broadcast(shape1, shape2):
    """Returns True if broadcast is necessary for batchmatmul."""
    return shape1[:-2] != shape2[:-2]

def _expand(x, ndim):
    """Expands x to ndim by prepending new dimensions at axis 0."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x

def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """Broadcasts x from shape_cur to shape_to."""
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)

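# Sketch of how these helpers combine in matmul below: a (3, 4) operand
# aligned against a rank-4 batch shape is first prepended to (1, 1, 3, 4) by
# _expand, then _broadcast_to tiles only the leading batch axes (the trailing
# matmul axes keep tile size 1) so both operands share the broadcast batch
# shape before BatchMatMul.
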
def matmul(x1, x2, dtype=None):
    """
    Returns the matrix product of two arrays.

    Note:
        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
        not supported.
        On GPU, the supported dtypes are np.float16 and np.float32.
        On CPU, the supported dtypes are np.float16 and np.float32.

    Args:
        x1 (Tensor): Input tensor, scalar not allowed.
        x2 (Tensor): Input tensor, scalar not allowed.
        dtype (:class:`mindspore.dtype`, optional): Defaults to None. Overrides the dtype of the
            output Tensor.

    Returns:
        Tensor or scalar, the matrix product of the inputs. This is a scalar only
        when both `x1`, `x2` are 1-d vectors.

    Raises:
        ValueError: If the last dimension of `x1` is not the same size as the
            second-to-last dimension of `x2`, or if a scalar value is passed in.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
        >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
        >>> output = np.matmul(x1, x2)
        >>> print(output)
        [[[  70.   76.   82.   88.   94.]
          [ 190.  212.  234.  256.  278.]
          [ 310.  348.  386.  424.  462.]]
         [[ 430.  484.  538.  592.  646.]
          [ 550.  620.  690.  760.  830.]
          [ 670.  756.  842.  928. 1014.]]]
    """
    # performs type promotion
    dtype1 = F.dtype(x1)
    dtype2 = F.dtype(x2)
    if not _check_same_type(dtype1, dtype2):
        x1 = x1.astype(mstype.float32)
        x2 = x2.astype(mstype.float32)

    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
    transpose_b = ndim2_orig == 1
    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
    # infers the shape of the output
    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
                                                  ndim1_orig, ndim2_orig, transpose_b)

    x1 = _expand(x1, 2)
    x2 = _expand(x2, 2)
    if F.rank(x2) == 2:
        if F.rank(x1) > 2:
            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
        res = P.MatMul(False, transpose_b)(x1, x2)
    else:
        # broadcasts x1.shape[:-2] with x2.shape[:-2]
        ndim_aligned = _max(ndim1_orig, ndim2_orig)
        x1 = _expand(x1, ndim_aligned)
        x2 = _expand(x2, ndim_aligned)
        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
        res = P.BatchMatMul(False, transpose_b)(x1, x2)

    if dtype is not None:
        res = res.astype(dtype)
    return F.reshape(res, shape_out)