pull/7173/head
commit e2b8810413 (parent 7af0d3374f) by wanyiming

@@ -339,6 +339,12 @@ def get_broadcast_matmul_shape(x_shape, y_shape):
 @constexpr
 def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
     """check col and row equal"""
+    if len(x1_shape) == 1:
+        transpose_x1 = False
+        x1_shape = (1,) + x1_shape
+    if len(x2_shape) == 1:
+        transpose_x2 = False
+        x2_shape = x2_shape + (1,)
     x1_last = x1_shape[-2:]
     x2_last = x2_shape[-2:]
     x1_col = x1_last[not transpose_x1]  # x1_col = x1_last[1] if (not transpose_x1) else x1_last[0]
@@ -348,27 +354,48 @@ def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
                          + f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.')
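Note: the 1-D promotion added above follows the usual matmul convention: a 1-D x1 is treated as a row vector, a 1-D x2 as a column vector, and the transpose flag is ignored for 1-D operands. A minimal pure-Python sketch of the same check, runnable without MindSpore (the function name is illustrative, not part of the patch):

def check_col_row_equal_sketch(x1_shape, x2_shape, transpose_x1=False, transpose_x2=False):
    # Promote 1-D shapes: x1 -> (1, n) row vector, x2 -> (n, 1) column vector.
    if len(x1_shape) == 1:
        transpose_x1, x1_shape = False, (1,) + tuple(x1_shape)
    if len(x2_shape) == 1:
        transpose_x2, x2_shape = False, tuple(x2_shape) + (1,)
    # Inner dimensions that must agree: x1's columns vs x2's rows.
    x1_col = x1_shape[-2:][not transpose_x1]
    x2_row = x2_shape[-2:][transpose_x2]
    assert x1_col == x2_row, f"inner dims differ: {x1_col} vs {x2_row}"

check_col_row_equal_sketch((4,), (2, 4, 5))   # (1, 4) x (..., 4, 5): ok
check_col_row_equal_sketch((2, 4, 5), (5,))   # (..., 4, 5) x (5, 1): ok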
+
+
+@constexpr
+def matmul_op_select(x1_shape, x2_shape, transpose_x1, transpose_x2):
+    """select matmul op"""
+    x1_dim, x2_dim = len(x1_shape), len(x2_shape)
+    if x1_dim == 1 and x2_dim == 1:
+        matmul_op = P.Mul()
+    elif x1_dim <= 2 and x2_dim <= 2:
+        transpose_x1 = False if x1_dim == 1 else transpose_x1
+        transpose_x2 = False if x2_dim == 1 else transpose_x2
+        matmul_op = P.MatMul(transpose_x1, transpose_x2)
+    elif x1_dim == 1 and x2_dim > 2:
+        matmul_op = P.BatchMatMul(False, transpose_x2)
+    elif x1_dim > 2 and x2_dim == 1:
+        matmul_op = P.BatchMatMul(transpose_x1, False)
+    else:
+        matmul_op = P.BatchMatMul(transpose_x1, transpose_x2)
+    return matmul_op
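The new dispatch can be summarized independently of the MindSpore primitives. A hedged sketch of the same rule (the returned strings merely stand in for P.Mul, P.MatMul, and P.BatchMatMul):

def select_op_name(x1_dim, x2_dim):
    # 1-D x 1-D: elementwise multiply, reduced to a dot product later.
    if x1_dim == 1 and x2_dim == 1:
        return "Mul"
    # Both ranks <= 2: plain 2-D MatMul (transpose dropped on 1-D sides).
    if x1_dim <= 2 and x2_dim <= 2:
        return "MatMul"
    # Any rank > 2: batched matmul over broadcast batch dimensions.
    return "BatchMatMul"

assert select_op_name(1, 1) == "Mul"
assert select_op_name(1, 2) == "MatMul"
assert select_op_name(1, 3) == "BatchMatMul"
assert select_op_name(4, 3) == "BatchMatMul"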
 class MatMul(Cell):
     """
     Multiplies matrix `x1` by matrix `x2`.
 
-    The rank of input tensors must be not less than `2`. The none-matrix dimensions(batch) of inputs
-    will be broadcasted and must be broadcastable.
+    - If both `x1` and `x2` are 1-dimensional, the dot product is returned.
+    - If the dimensions of `x1` and `x2` are both no greater than 2, the matrix-matrix product is returned. Note that
+      if one of `x1` and `x2` is 1-dimensional, that argument is first expanded to 2 dimensions. After the matrix
+      multiplication, the expanded dimension is removed.
+    - If at least one of `x1` and `x2` is N-dimensional (N > 2), the non-matrix dimensions (batch) of the inputs are
+      broadcast and must be broadcastable. Note that if one of `x1` and `x2` is 1-dimensional, that argument is first
+      expanded to 2 dimensions and the non-matrix dimensions are then broadcast. After the matrix multiplication, the
+      expanded dimension is removed.
 
     Args:
         transpose_x1 (bool): If true, `x1` is transposed before multiplication. Default: False.
         transpose_x2 (bool): If true, `x2` is transposed before multiplication. Default: False.
 
     Inputs:
-        - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*A, N, C)`,
-          where :math:`*A` represents the batch size of `x1` which can be multidimensional.
-          If `transpose_a` is True, its shape must be :math:`(*A, N, C)` after transposing.
-        - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`,
-          where :math:`*B` represents the batch size of `x2` which can be multidimensional.
-          If `transpose_b` is True, its shape must be :math:`(*B, C, M)` after transposing.
+        - **input_x1** (Tensor) - The first tensor to be multiplied.
+        - **input_x2** (Tensor) - The second tensor to be multiplied.
 
     Outputs:
-        Tensor, the shape of the output tensor is :math:`(*L, N, M)`. :math:`*L` is the batch size after broadcasting.
+        Tensor, the shape of the output tensor depends on the dimensions of the input tensors.
 
     Examples:
         >>> net = nn.MatMul()
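The behaviours described by the new docstring bullets match what the tests added below assert. A short usage sketch (shapes borrowed from those tests; this assumes a configured MindSpore context):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

net = nn.MatMul()
vec = Tensor(np.random.randn(4).astype(np.float32))
batch = Tensor(np.random.randn(2, 4, 5).astype(np.float32))

print(net(vec, vec).shape)    # () -- 1-D x 1-D collapses to a scalar dot product
print(net(vec, batch).shape)  # (2, 5) -- vec expanded to (1, 4), multiplied, then squeezed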
@@ -387,13 +414,26 @@ class MatMul(Cell):
         self.transpose_x1 = transpose_x1
         self.transpose_x2 = transpose_x2
         self.shape_op = P.Shape()
-        self.matmul_op = P.MatMul(self.transpose_x1, self.transpose_x2)
-        self.batch_matmul_op = P.BatchMatMul(self.transpose_x1, self.transpose_x2)
+        self.expand_op = P.ExpandDims()
+        self.squeeze_left_op = P.Squeeze(-2)
+        self.squeeze_right_op = P.Squeeze(-1)
+        self.reduce_sum_op = P.ReduceSum(keep_dims=False)
 
     def construct(self, x1, x2):
         x1_shape = self.shape_op(x1)
         x2_shape = self.shape_op(x2)
         check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
+        matmul_op = matmul_op_select(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
+        x1_dim, x2_dim = len(x1_shape), len(x2_shape)
+        if x1_dim == x2_dim and x2_dim == 1:
+            return self.reduce_sum_op(matmul_op(x1, x2), -1)
+        if x1_dim == 1:
+            x1 = self.expand_op(x1, 0)
+            x1_shape = self.shape_op(x1)
+        if x2_dim == 1:
+            x2 = self.expand_op(x2, 1)
+            x2_shape = self.shape_op(x2)
         x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
         x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
@@ -402,8 +442,12 @@ class MatMul(Cell):
             x1 = x1_broadcast_to(x1)
         if x2_broadcast_shape != x2_shape:
             x2 = x2_broadcast_to(x2)
-        if len(x1_broadcast_shape) == 2:
-            matmul_broadcast = self.matmul_op(x1, x2)
-        else:
-            matmul_broadcast = self.batch_matmul_op(x1, x2)
+        matmul_broadcast = matmul_op(x1, x2)
+        if x1_dim == 1:
+            matmul_broadcast = self.squeeze_left_op(matmul_broadcast)
+        if x2_dim == 1:
+            matmul_broadcast = self.squeeze_right_op(matmul_broadcast)
         return matmul_broadcast
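The rewritten construct path for a mixed 1-D/N-D case (expand, broadcast, batched matmul, squeeze) can be traced with NumPy alone; a sketch under the same shapes the tests use:

import numpy as np

x1 = np.random.randn(4).astype(np.float32)        # 1-D left operand
x2 = np.random.randn(2, 4, 5).astype(np.float32)  # batched right operand

x1e = x1[np.newaxis, :]                # ExpandDims(x1, 0) -> (1, 4)
x1b = np.broadcast_to(x1e, (2, 1, 4))  # BroadcastTo the batch dim of x2
out = x1b @ x2                         # BatchMatMul -> (2, 1, 5)
out = out.squeeze(-2)                  # Squeeze(-2) removes the expanded dim
assert out.shape == (2, 5)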

@@ -54,3 +54,51 @@ def test_x1_3D_transpose_x1_True_x2_3D_transpose_x2_True():
     net = Net(transpose_x1, transpose_x2)
     output = net(Tensor(x1), Tensor(x2))
     assert output.shape == (2, 6, 4)
+
+
+def test_x1_1D_x2_1D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(4).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == ()
+
+
+def test_x1_1D_x2_3D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 5)
+
+
+def test_x1_3D_x2_1D():
+    x1 = np.random.randn(2, 4, 5).astype(np.float32)
+    x2 = np.random.randn(5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4)
+
+
+def test_x1_1D_transpose_x1_True_x2_3D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 5)
+
+
+def test_x1_3D_x2_1D_transpose_x2_True():
+    x1 = np.random.randn(2, 4, 5).astype(np.float32)
+    x2 = np.random.randn(5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4)
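As a sanity check, the expected shapes in these new cases agree with NumPy's matmul broadcasting rules (for a 1-D operand the transpose flag is ignored, so the transpose variants reduce to the same shapes):

import numpy as np

assert (np.random.randn(4) @ np.random.randn(4)).shape == ()
assert (np.random.randn(4) @ np.random.randn(2, 4, 5)).shape == (2, 5)
assert (np.random.randn(2, 4, 5) @ np.random.randn(5)).shape == (2, 4)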

@@ -0,0 +1,104 @@
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class Net(nn.Cell):
+    def __init__(self, transpose_x1, transpose_x2):
+        super(Net, self).__init__()
+        self.matmul = nn.MatMul(transpose_x1, transpose_x2)
+
+    def construct(self, x1, x2):
+        return self.matmul(x1, x2)
+
+
+def test_x1_2D_x2_3D():
+    x1 = np.random.randn(16, 64).astype(np.float32)
+    x2 = np.random.randn(32, 64, 20).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (32, 16, 20)
+
+
+def test_x1_4D_x2_3D_transpose_x2_True():
+    x1 = np.random.randn(3, 2, 3, 4).astype(np.float32)
+    x2 = np.random.randn(1, 5, 4).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (3, 2, 3, 5)
+
+
+def test_x1_3D_transpose_x1_True_x2_2D():
+    x1 = np.random.randn(2, 3, 4).astype(np.float32)
+    x2 = np.random.randn(3, 4).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4, 4)
+
+
+def test_x1_3D_transpose_x1_True_x2_3D_transpose_x2_True():
+    x1 = np.random.randn(2, 5, 6).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 6, 4)
+
+
+def test_x1_1D_x2_1D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(4).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == ()
+
+
+def test_x1_1D_x2_3D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 5)
+
+
+def test_x1_3D_x2_1D():
+    x1 = np.random.randn(2, 4, 5).astype(np.float32)
+    x2 = np.random.randn(5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4)
+
+
+def test_x1_1D_transpose_x1_True_x2_3D():
+    x1 = np.random.randn(4).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 5)
+
+
+def test_x1_3D_x2_1D_transpose_x2_True():
+    x1 = np.random.randn(2, 4, 5).astype(np.float32)
+    x2 = np.random.randn(5).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4)