|
|
|
@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase):
|
|
|
|
|
y_data = np.arange(16, dtype='float32').reshape((2, 4, 2))
|
|
|
|
|
y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4))
|
|
|
|
|
y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2))
|
|
|
|
|
y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2))
|
|
|
|
|
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1)
|
|
|
|
|
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2)
|
|
|
|
|
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|