fix bmm enforce equal batch (#27694)

my_2.0rc
yaoxuefeng 4 years ago committed by GitHub
parent 742cbe6660
commit e496640bf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -79,8 +79,10 @@ class TestBmmAPIError(unittest.TestCase):
y_data = np.arange(16, dtype='float32').reshape((2, 4, 2))
y_data_wrong1 = np.arange(16, dtype='float32').reshape((2, 2, 4))
y_data_wrong2 = np.arange(16, dtype='float32').reshape((2, 2, 2, 2))
y_data_wrong3 = np.arange(24, dtype='float32').reshape((3, 4, 2))
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong1)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong2)
self.assertRaises(ValueError, paddle.bmm, x_data, y_data_wrong3)
if __name__ == "__main__":

@ -848,6 +848,10 @@ def bmm(x, y, name=None):
raise ValueError(
"x's width must be equal with y's height. But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape))
if x_shape[0] != y_shape[0]:
raise ValueError(
"x's batch (shape[0]) must be equal with y's batch (shape[0]). But received x's shape: {}, y's shape: {}".
format(x_shape, y_shape))
helper = LayerHelper('bmm', **locals())
if in_dygraph_mode():
return core.ops.bmm(x, y)

Loading…
Cancel
Save