You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
3.7 KiB
120 lines
3.7 KiB
import unittest
|
|
import numpy as np
|
|
from op_test import OpTest
|
|
|
|
|
|
def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
    """Return a pair of operand shapes that can be matrix-multiplied.

    Fixed sizes are used:
    - M = 3: rows of X.
    - N = 4: columns of Y.
    - K = 5: the contracted dimension.
    - A leading batch dimension of 2 is added to 3-D operands.

    When a 1-D operand is transposed it behaves like an M x 1 column (X) or
    a 1 x N row (Y), so K is forced to 1 to keep the operands aligned.

    Args:
        dim_X: rank of X; one of 1, 2 or 3.
        dim_Y: rank of Y; one of 1, 2 or 3.
        transpose_X: whether X is transposed before multiplication.
        transpose_Y: whether Y is transposed before multiplication.

    Returns:
        ``(shape_X, shape_Y)``, each a list of ints.

    Raises:
        ValueError: if ``dim_X`` or ``dim_Y`` is not 1, 2 or 3.
    """
    # Validate up front; otherwise an unsupported rank would leave the shape
    # unassigned and fail later with an obscure UnboundLocalError.
    for arg_name, rank in (('dim_X', dim_X), ('dim_Y', dim_Y)):
        if rank not in (1, 2, 3):
            raise ValueError(
                '{} must be 1, 2 or 3, got {!r}'.format(arg_name, rank))

    BATCH_SIZE = 2
    M = 3
    N = 4
    K = 5
    # A transposed vector only lines up with the other operand when the
    # contracted dimension has size 1.
    if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y):
        K = 1

    if dim_X == 1:
        shape_X = [M] if transpose_X else [K]
    else:
        shape_X = [K, M] if transpose_X else [M, K]
        if dim_X == 3:
            shape_X = [BATCH_SIZE] + shape_X

    if dim_Y == 1:
        shape_Y = [N] if transpose_Y else [K]
    else:
        shape_Y = [N, K] if transpose_Y else [K, N]
        if dim_Y == 3:
            shape_Y = [BATCH_SIZE] + shape_Y

    return shape_X, shape_Y
|
|
|
|
|
|
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
    """Reference forward implementation using np.matmul.

    np.matmul has no transpose flags, so requested transposes are applied to
    the operands first:

    * a transposed 1-D X is reshaped to a ``(X.size, 1)`` column;
    * a transposed 1-D Y is reshaped to a ``(1, Y.size)`` row;
    * transposed 2-D operands are transposed;
    * transposed 3-D operands keep their leading axis and swap the last two.

    Args:
        X: left operand (numpy array).
        Y: right operand (numpy array).
        transpose_X: transpose X before multiplying.
        transpose_Y: transpose Y before multiplying.

    Returns:
        The product as a numpy array. A 0-dimensional result (e.g. two 1-D
        operands) is returned as an array of shape ``(1,)``. That array keeps
        the dtype of the product.

    Raises:
        ValueError: if a transposed operand does not have 1 to 3 dimensions.
    """
    # np.matmul does not support the transpose flags, so we manually
    # transpose X and Y appropriately.
    if transpose_X:
        if X.ndim == 1:
            X = X.reshape((X.size, 1))
        elif X.ndim == 2:
            X = X.T
        elif X.ndim == 3:
            # Keep the leading (stack) axis; swap the two matrix axes.
            X = np.transpose(X, (0, 2, 1))
        else:
            raise ValueError('X must have between 1 and 3 dimensions')
    if transpose_Y:
        if Y.ndim == 1:
            Y = Y.reshape((1, Y.size))
        elif Y.ndim == 2:
            Y = Y.T
        elif Y.ndim == 3:
            Y = np.transpose(Y, (0, 2, 1))
        else:
            raise ValueError('Y must have between 1 and 3 dimensions')
    Out = np.matmul(X, Y)
    if not Out.shape:
        # We do not support 0-dimensional Tensors (scalars). So where
        # np.matmul outputs a scalar, we must convert to a Tensor of
        # shape (1, ) instead.
        # Everywhere else, we are compatible with np.matmul.
        # Use the product's own dtype: forcing float32 here would silently
        # downcast float64 inputs in this case only.
        Out = np.array([Out], dtype=Out.dtype)
    return Out
|
|
|
|
|
|
class Generator(object):
    """Shared body of the matmul op tests.

    Mixed into OpTest subclasses. The concrete class must define the class
    attributes ``shape_X``, ``shape_Y``, ``transpose_X`` and ``transpose_Y``,
    which ``setUp`` reads to build the case.
    """

    def setUp(self):
        """Build random float32 inputs, attrs and the reference output."""
        self.op_type = "matmul"
        X = np.random.random(self.shape_X).astype("float32")
        Y = np.random.random(self.shape_Y).astype("float32")
        Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y)
        self.inputs = {'X': X, 'Y': Y}
        self.attrs = {
            'transpose_X': self.transpose_X,
            'transpose_Y': self.transpose_Y
        }
        self.outputs = {'Out': Out}

    def test_check_output(self):
        """Compare the op's forward output with the reference (atol=1e-2)."""
        self.check_output(atol=1e-2)

    def test_check_grad_normal(self):
        """Gradient check of 'Out' with respect to both X and Y."""
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)

    def test_check_grad_ignore_x(self):
        """Gradient check with respect to Y, with X in ``no_grad_set``."""
        # A set literal, not set("X"): set() of a string would split a
        # multi-character variable name into single characters.
        self.check_grad(
            ['Y'], 'Out', max_relative_error=0.5, no_grad_set={'X'})

    def test_check_grad_ignore_y(self):
        """Gradient check with respect to X, with Y in ``no_grad_set``."""
        self.check_grad(
            ['X'], 'Out', max_relative_error=0.5, no_grad_set={'Y'})
|
|
|
|
|
|
# Generate test cases for all possibilities: every combination of operand
# rank (1, 2 or 3) and transpose flag becomes its own OpTest subclass. Each
# subclass is bound as a module attribute so unittest's loader can find it.
for dim_X, dim_Y, transpose_X, transpose_Y in (
        (rank_x, rank_y, trans_x, trans_y)
        for rank_x in (1, 2, 3)
        for rank_y in (1, 2, 3)
        for trans_x in (False, True)
        for trans_y in (False, True)):
    shape_X, shape_Y = generate_compatible_shapes(dim_X, dim_Y, transpose_X,
                                                  transpose_Y)
    test_name = 'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
        dim_X, dim_Y, transpose_X, transpose_Y)
    test_class = type(test_name, (Generator, OpTest), dict(
        shape_X=shape_X,
        shape_Y=shape_Y,
        transpose_X=transpose_X,
        transpose_Y=transpose_Y,
    ))
    globals()[test_name] = test_class


if __name__ == "__main__":
    unittest.main()
|