|
|
|
@ -59,19 +59,18 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
|
|
|
|
|
X = X.reshape((X.size, 1))
|
|
|
|
|
elif X.ndim == 2:
|
|
|
|
|
X = X.T
|
|
|
|
|
elif X.ndim == 3:
|
|
|
|
|
X = np.transpose(X, (0, 2, 1))
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('X must have between 1 and 3 dimensions')
|
|
|
|
|
dim = [i for i in range(len(X.shape))]
|
|
|
|
|
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
|
|
|
|
|
X = np.transpose(X, tuple(dim))
|
|
|
|
|
if transpose_Y:
|
|
|
|
|
if Y.ndim == 1:
|
|
|
|
|
Y = Y.reshape((1, Y.size))
|
|
|
|
|
elif Y.ndim == 2:
|
|
|
|
|
Y = Y.T
|
|
|
|
|
elif Y.ndim == 3:
|
|
|
|
|
Y = np.transpose(Y, (0, 2, 1))
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Y must have between 1 and 3 dimensions')
|
|
|
|
|
dim = [i for i in range(len(Y.shape))]
|
|
|
|
|
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
|
|
|
|
|
Y = np.transpose(Y, tuple(dim))
|
|
|
|
|
|
|
|
|
|
Out = np.matmul(X, Y)
|
|
|
|
|
if not Out.shape:
|
|
|
|
|
# We do not support 0-dimensional Tensors (scalars). So where
|
|
|
|
@ -120,30 +119,49 @@ for dim_X in [1, 2, 3]:
|
|
|
|
|
dim_X, dim_Y, transpose_X, transpose_Y))
|
|
|
|
|
shape_X, shape_Y = generate_compatible_shapes(
|
|
|
|
|
dim_X, dim_Y, transpose_X, transpose_Y)
|
|
|
|
|
test_class = type(test_name, (Generator, OpTest), {
|
|
|
|
|
globals()[test_name] = type(test_name, (Generator, OpTest), {
|
|
|
|
|
'shape_X': shape_X,
|
|
|
|
|
'shape_Y': shape_Y,
|
|
|
|
|
'transpose_X': transpose_X,
|
|
|
|
|
'transpose_Y': transpose_Y,
|
|
|
|
|
})
|
|
|
|
|
globals()[test_name] = test_class
|
|
|
|
|
|
|
|
|
|
# Test case 4-dim
|
|
|
|
|
dim_X = 4
|
|
|
|
|
dim_Y = 4
|
|
|
|
|
transpose_X = False
|
|
|
|
|
transpose_Y = False
|
|
|
|
|
test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
|
|
|
|
|
dim_X, dim_Y, transpose_X, transpose_Y))
|
|
|
|
|
|
|
|
|
|
shape_X = [2, 2, 2, 3]
|
|
|
|
|
shape_Y = [2, 2, 3, 4]
|
|
|
|
|
test_class = type(test_name, (Generator, OpTest), {
|
|
|
|
|
'shape_X': shape_X,
|
|
|
|
|
'shape_Y': shape_Y,
|
|
|
|
|
'transpose_X': transpose_X,
|
|
|
|
|
'transpose_Y': transpose_Y,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_compatible_shapes(dim, transpose_X, transpose_Y):
|
|
|
|
|
M = 2
|
|
|
|
|
N = 4
|
|
|
|
|
K = 3
|
|
|
|
|
shape_X = [2 for _ in range(dim - 2)]
|
|
|
|
|
shape_Y = [2 for _ in range(dim - 2)]
|
|
|
|
|
|
|
|
|
|
if transpose_X:
|
|
|
|
|
shape_X = shape_X + [K, M]
|
|
|
|
|
else:
|
|
|
|
|
shape_X = shape_X + [M, K]
|
|
|
|
|
|
|
|
|
|
if transpose_Y:
|
|
|
|
|
shape_Y = shape_Y + [N, K]
|
|
|
|
|
else:
|
|
|
|
|
shape_Y = shape_Y + [K, N]
|
|
|
|
|
|
|
|
|
|
return shape_X, shape_Y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test case n-dim
|
|
|
|
|
for dim in [4]:
|
|
|
|
|
for transpose_X in [False, True]:
|
|
|
|
|
for transpose_Y in [False, True]:
|
|
|
|
|
test_name = (
|
|
|
|
|
'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
|
|
|
|
|
dim, dim, transpose_X, transpose_Y))
|
|
|
|
|
shape_X, shape_Y = generate_compatible_shapes(dim, transpose_X,
|
|
|
|
|
transpose_Y)
|
|
|
|
|
globals()[test_name] = type(test_name, (Generator, OpTest), {
|
|
|
|
|
'shape_X': shape_X,
|
|
|
|
|
'shape_Y': shape_Y,
|
|
|
|
|
'transpose_X': transpose_X,
|
|
|
|
|
'transpose_Y': transpose_Y,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|