|
|
|
@ -611,9 +611,9 @@ class CumProd(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
class MatMul(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|
Multiplies matrix `a` by matrix `b`.
|
|
|
|
|
Multiplies matrix `a` and matrix `b`.
|
|
|
|
|
|
|
|
|
|
The rank of input tensors must be `2`.
|
|
|
|
|
The rank of input tensors must equal to `2`.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
transpose_a (bool): If True, `a` is transposed before multiplication. Default: False.
|
|
|
|
@ -629,10 +629,10 @@ class MatMul(PrimitiveWithInfer):
|
|
|
|
|
Tensor, the shape of the output tensor is :math:`(N, M)`.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
|
|
|
|
>>> input_y = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
|
|
|
|
>>> input_x1 = Tensor(np.ones(shape=[1, 3]), mindspore.float32)
|
|
|
|
|
>>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
|
|
|
|
|
>>> matmul = P.MatMul()
|
|
|
|
|
>>> output = matmul(input_x, input_y)
|
|
|
|
|
>>> output = matmul(input_x1, input_x2)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
@ -643,42 +643,44 @@ class MatMul(PrimitiveWithInfer):
|
|
|
|
|
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
|
|
|
|
|
self.add_prim_attr("io_format", "ND")
|
|
|
|
|
|
|
|
|
|
def check_shape_size(self, x, y):
|
|
|
|
|
if len(x) != 2 or len(y) != 2:
|
|
|
|
|
raise ValueError('MatMul input x, y should be the same dimension size and should be '
|
|
|
|
|
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
|
|
|
|
|
def check_shape_size(self, x1, x2):
|
|
|
|
|
if len(x1) != 2 or len(x2) != 2:
|
|
|
|
|
raise ValueError('P.MatMul inputs x1, x2 should has the same dimension size and '
|
|
|
|
|
+ f'equal to 2, while x1 size is ({len(x1)}) and x2 size is ({len(x2)}).')
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x, y):
|
|
|
|
|
self.check_shape_size(x, y)
|
|
|
|
|
def infer_shape(self, x1, x2):
|
|
|
|
|
self.check_shape_size(x1, x2)
|
|
|
|
|
cls_name = self.name
|
|
|
|
|
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
|
|
|
|
for i in range(len(x) - 2):
|
|
|
|
|
if x[i] != y[i]:
|
|
|
|
|
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, while x is {x[i]}, y is {y[i]}')
|
|
|
|
|
|
|
|
|
|
# validate whether last two dims satifing matrix multiply
|
|
|
|
|
x_last = x[-2:]
|
|
|
|
|
y_last = y[-2:]
|
|
|
|
|
|
|
|
|
|
x_col = x_last[not self.transpose_a] # x_col = x_last[1] if (not transpose_a) else x_last[0]
|
|
|
|
|
y_row = y_last[self.transpose_b] # y_row = y_last[0] if (not transpose_b) else y_last[1]
|
|
|
|
|
if x_col != y_row:
|
|
|
|
|
for i in range(len(x1) - 2):
|
|
|
|
|
if x1[i] != x2[i]:
|
|
|
|
|
raise ValueError(f'For \'{cls_name}\' shape in dim[{i}] not the same, '
|
|
|
|
|
+ f'while x1 is {x1[i]}, x2 is {x2[i]}')
|
|
|
|
|
|
|
|
|
|
# validate whether last two dims satisfying matrix multiply
|
|
|
|
|
x1_last = x1[-2:]
|
|
|
|
|
x2_last = x2[-2:]
|
|
|
|
|
# x1_col = x1_last[1] if (not transpose_a) else x1_last[0]
|
|
|
|
|
x1_col = x1_last[not self.transpose_a]
|
|
|
|
|
# x2_row = x2_last[0] if (not transpose_b) else x2_last[1]
|
|
|
|
|
x2_row = x2_last[self.transpose_b]
|
|
|
|
|
if x1_col != x2_row:
|
|
|
|
|
raise ValueError(f'For \'{cls_name}\' evaluator shapes of inputs can not do this operator,'
|
|
|
|
|
+ f' got {x_col} and {y_row}, with x shape {x}(transpose_a={self.transpose_a})'
|
|
|
|
|
+ f', y shape {y}(transpose_b={self.transpose_b}).')
|
|
|
|
|
+ f' got {x1_col} and {x2_row}, with x1 shape {x1}(transpose_a={self.transpose_a})'
|
|
|
|
|
+ f', x2 shape {x2}(transpose_b={self.transpose_b}).')
|
|
|
|
|
# set attribute
|
|
|
|
|
self.add_prim_attr('transpose_x1', self.transpose_a)
|
|
|
|
|
self.add_prim_attr('transpose_x2', self.transpose_b)
|
|
|
|
|
|
|
|
|
|
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
|
|
|
|
|
ret_dims = x1[: -2] + [x1_last[self.transpose_a], x2_last[not self.transpose_b]]
|
|
|
|
|
return ret_dims
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x, y):
|
|
|
|
|
args = {"x": x, "y": y}
|
|
|
|
|
def infer_dtype(self, x1, x2):
|
|
|
|
|
args = {"x1": x1, "x2": x2}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
|
|
|
|
if x.element_type() == mstype.int8:
|
|
|
|
|
if x1.element_type() == mstype.int8:
|
|
|
|
|
return mstype.tensor_type(mstype.int32)
|
|
|
|
|
return x
|
|
|
|
|
return x1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchMatMul(MatMul):
|
|
|
|
|