|
|
|
@ -648,7 +648,7 @@ class MatMul(PrimitiveWithInfer):
|
|
|
|
|
raise ValueError('MatMul input x, y should be the same dimension size and should be '
|
|
|
|
|
+ f'equal to 2, while x size = {len(x)}, y size= {len(y)}')
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x, y, bias=None):
|
|
|
|
|
def infer_shape(self, x, y):
|
|
|
|
|
self.check_shape_size(x, y)
|
|
|
|
|
cls_name = self.name
|
|
|
|
|
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
|
|
|
|
@ -673,7 +673,7 @@ class MatMul(PrimitiveWithInfer):
|
|
|
|
|
ret_dims = x[: -2] + [x_last[self.transpose_a], y_last[not self.transpose_b]]
|
|
|
|
|
return ret_dims
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, x, y, bias=None):
|
|
|
|
|
def infer_dtype(self, x, y):
|
|
|
|
|
args = {"x": x, "y": y}
|
|
|
|
|
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
|
|
|
|
|
if x.element_type() == mstype.int8:
|
|
|
|
|