add addmm dyg mode, test=develop (#24095)

revert-22778-infer_var_type
littletomatodonkey 5 years ago committed by GitHub
parent 96ffebef55
commit eec18202f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -133,5 +133,19 @@ class TestAddMMOp3(OpTest):
self.check_grad(['Input'], 'Out', no_grad_set=None)
class TestAddMMOp4(unittest.TestCase):
def test_api_with_dygraph(self):
np_input = np.random.random((20, 30)).astype(np.float32)
np_x = np.random.random((20, 6)).astype(np.float32)
np_y = np.random.random((6, 30)).astype(np.float32)
with fluid.dygraph.guard():
input = fluid.dygraph.to_variable(np_input)
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
out = paddle.tensor.addmm(input, x, y)
assert np.allclose(np_input + np.dot(np_x, np_y), out.numpy())
if __name__ == "__main__":
unittest.main()

@ -1000,6 +1000,10 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
# [[10.5 10.5]
# [10.5 10.5]]
"""
if in_dygraph_mode():
out = core.ops.addmm(input, x, y, "Alpha", alpha, "Beta", beta)
return out
inputs = {'Input': input, "X": x, "Y": y}
attrs = {'Alpha': alpha, 'Beta': beta}

Loading…
Cancel
Save