@@ -13,6 +13,8 @@
# limitations under the License.
# ============================================================================
"""math Operations."""
from itertools import zip_longest
from collections import deque
import numpy as np
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype
@@ -486,3 +488,147 @@ def batch_dot(x1, x2, axes=None):
    final_result = squeeze_minus_one_op(final_result)

    return final_result


@constexpr
def _check_same_type(dtype1, dtype2):
    """Returns True if dtype1 is the same as dtype2."""
    return dtype1 == dtype2


@constexpr
def _max(*args):
    """Returns the maximum value."""
    return max(*args)


@constexpr
def _min(*args):
    """Returns the minimum value."""
    return min(*args)


@constexpr
def _infer_shape_rem(shape1, shape2, ndim1, ndim2, transpose_b):
    """Infers the shape of the last two dimensions after performing matmul."""
    shape_rem = []
    if ndim1 >= 2:
        shape_rem.append(shape1[-2])
    if transpose_b:
        if ndim2 >= 2:
            shape_rem.append(shape2[-2])
    else:
        if ndim2 >= 1:
            shape_rem.append(shape2[-1])
    return tuple(shape_rem)


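# Illustrative check of the helper above (not part of the original patch);
# the values assume the constexpr function is evaluated eagerly:
#
#     >>> _infer_shape_rem((2, 3, 4), (4, 5), 3, 2, False)
#     (3, 5)
#     >>> _infer_shape_rem((2, 3, 4), (4,), 3, 1, True)  # 1-D x2 drops the trailing dim
#     (3,)

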
@constexpr
def _check_matmul_shapes(shape1, shape2):
    """Checks shape1 and shape2 are valid to perform matmul, and returns output shape after broadcasting."""
    ndim1, ndim2 = len(shape1), len(shape2)
    if ndim1 < 1 or ndim2 < 1:
        raise ValueError('input operands must have at least 1 dimension')
    if ndim2 >= 2 and shape1[-1] != shape2[-2]:
        raise ValueError(f'mismatch in core dimension of input operands (size '
                         f'{shape1[-1]} is different from {shape2[-2]})')
    shape_out = deque()
    for items in zip_longest(reversed(shape1[:-2]), reversed(shape2[:-2]), fillvalue=1):
        max_size = max(items)
        if any(item not in (1, max_size) for item in items):
            raise ValueError(f'operands could not be broadcast together with shapes {shape1} {shape2}')
        shape_out.appendleft(max_size)
    return tuple(shape_out)


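# Illustrative check (not part of the original patch): batch dimensions are
# compared right-to-left, missing dimensions are padded with 1, and a size
# of 1 stretches to match:
#
#     >>> _check_matmul_shapes((2, 1, 3, 4), (5, 4, 6))
#     (2, 5)

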
@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape."""
    size = [1]*ndim
    for idx, (i, j) in enumerate(zip(shape, out_shape)):
        if i != j:
            size[idx] = j
    return tuple(size)


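# Illustrative check (not part of the original patch): axes that already
# match keep a tile factor of 1, and size-1 axes are tiled up to the target:
#
#     >>> _tile_size((1, 3), (5, 3), 2)
#     (5, 1)

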
@constexpr
def _check_need_broadcast(shape1, shape2):
    """Returns True if broadcast is necessary for batchmatmul."""
    return shape1[:-2] != shape2[:-2]


def _expand(x, ndim):
    """Expands x to ndim by prepending size-1 dimensions at axis 0."""
    while F.rank(x) < ndim:
        x = F.expand_dims(x, 0)
    return x


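# Illustrative shape walk (not part of the original patch): _expand prepends
# size-1 axes until the requested rank is reached, so a (4, 5) tensor expanded
# to ndim=4 comes out with shape (1, 1, 4, 5); an input already at rank >= ndim
# is returned unchanged.

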
def _broadcast_to(x, shape_cur, shape_to, ndim_to):
    """Broadcasts x from shape_cur to shape_to."""
    size = _tile_size(shape_cur, shape_to, ndim_to)
    return F.tile(x, size)


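# Illustrative shape walk (not part of the original patch): for x of shape
# (1, 3, 4) with shape_cur=(1,), shape_to=(5,) and ndim_to=3, _tile_size
# returns (5, 1, 1), so F.tile stretches the size-1 batch axis and x comes
# out with shape (5, 3, 4).

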
def matmul(x1, x2, dtype=None):
    """
    Returns the matrix product of two arrays.

    Note:
        Numpy arguments `out`, `casting`, `order`, `subok`, `signature`, and `extobj` are
        not supported.
        On GPU, the supported dtypes are np.float16 and np.float32.
        On CPU, the supported dtypes are np.float16 and np.float32.

    Args:
        x1 (Tensor): Input tensor, scalar not allowed.
        x2 (Tensor): Input tensor, scalar not allowed.
        dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the
            output Tensor.

    Returns:
        Tensor or scalar, the matrix product of the inputs. This is a scalar only
        when both `x1`, `x2` are 1-d vectors.

    Raises:
        ValueError: If the last dimension of `x1` is not the same size as the
            second-to-last dimension of `x2`, or if a scalar value is passed in.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x1 = np.arange(2*3*4).reshape(2, 3, 4).astype('float32')
        >>> x2 = np.arange(4*5).reshape(4, 5).astype('float32')
        >>> output = np.matmul(x1, x2)
        >>> print(output)
        [[[  70.   76.   82.   88.   94.]
          [ 190.  212.  234.  256.  278.]
          [ 310.  348.  386.  424.  462.]]
         [[ 430.  484.  538.  592.  646.]
          [ 550.  620.  690.  760.  830.]
          [ 670.  756.  842.  928. 1014.]]]
    """
    # performs type promotion: both inputs are cast to float32 when their dtypes differ
    dtype1 = F.dtype(x1)
    dtype2 = F.dtype(x2)
    if not _check_same_type(dtype1, dtype2):
        x1 = x1.astype(mstype.float32)
        x2 = x2.astype(mstype.float32)

    ndim1_orig, ndim2_orig = F.rank(x1), F.rank(x2)
    shape1_orig, shape2_orig = F.shape(x1), F.shape(x2)
    transpose_b = ndim2_orig == 1
    shape_backbone = _check_matmul_shapes(shape1_orig, shape2_orig)
    # infers the shape of the output
    shape_out = shape_backbone + _infer_shape_rem(shape1_orig, shape2_orig,
                                                  ndim1_orig, ndim2_orig, transpose_b)

    x1 = _expand(x1, 2)
    x2 = _expand(x2, 2)
    if F.rank(x2) == 2:
        if F.rank(x1) > 2:
            x1 = F.reshape(x1, (-1, shape1_orig[-1]))
        res = P.MatMul(False, transpose_b)(x1, x2)
    else:
        # broadcasts x1.shape[:-2] with x2.shape[:-2]
        ndim_aligned = _max(ndim1_orig, ndim2_orig)
        x1 = _expand(x1, ndim_aligned)
        x2 = _expand(x2, ndim_aligned)
        shape1_aligned, shape2_aligned = F.shape(x1), F.shape(x2)
        x1 = _broadcast_to(x1, shape1_aligned[:-2], shape_backbone, ndim_aligned)
        x2 = _broadcast_to(x2, shape2_aligned[:-2], shape_backbone, ndim_aligned)
        res = P.BatchMatMul(False, transpose_b)(x1, x2)

    if dtype is not None:
        res = res.astype(dtype)
    return F.reshape(res, shape_out)
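

# Illustrative usage (not part of the original patch), assuming PyNative
# execution and that this module's matmul is imported directly:
#
#     >>> from mindspore import Tensor
#     >>> import numpy as onp
#     >>> a = Tensor(onp.arange(2*3*4).reshape(2, 3, 4).astype('float32'))
#     >>> b = Tensor(onp.arange(4).astype('float32'))
#     >>> matmul(a, b).shape  # 1-D x2: the last axis is contracted away
#     (2, 3)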