|
|
|
@ -280,6 +280,42 @@ class API_TestMm(unittest.TestCase):
|
|
|
|
|
"two value is\
|
|
|
|
|
{}\n{}, check diff!".format(np_res, expected_result))
|
|
|
|
|
|
|
|
|
|
def test_dygraph_with_out(self):
|
|
|
|
|
device = fluid.CPUPlace()
|
|
|
|
|
with fluid.dygraph.guard(device):
|
|
|
|
|
input_array1 = np.random.rand(3, 4).astype("float64")
|
|
|
|
|
input_array2 = np.random.rand(4, 3).astype("float64")
|
|
|
|
|
out_array = np.random.rand(3, 3).astype("float64")
|
|
|
|
|
data1 = fluid.dygraph.to_variable(input_array1)
|
|
|
|
|
data2 = fluid.dygraph.to_variable(input_array2)
|
|
|
|
|
paddle_out_holder = fluid.dygraph.to_variable(out_array)
|
|
|
|
|
out = paddle.mm(data1, data2, out=paddle_out_holder)
|
|
|
|
|
self.assertTrue(np.allclose(paddle_out_holder.numpy(), out.numpy()))
|
|
|
|
|
|
|
|
|
|
def test_dygraph_without_out(self):
|
|
|
|
|
device = fluid.CPUPlace()
|
|
|
|
|
with fluid.dygraph.guard(device):
|
|
|
|
|
input_array1 = np.random.rand(3, 4).astype("float64")
|
|
|
|
|
input_array2 = np.random.rand(4, 3).astype("float64")
|
|
|
|
|
data1 = fluid.dygraph.to_variable(input_array1)
|
|
|
|
|
data2 = fluid.dygraph.to_variable(input_array2)
|
|
|
|
|
out = paddle.mm(data1, data2)
|
|
|
|
|
expected_result = np.matmul(input_array1, input_array2)
|
|
|
|
|
self.assertTrue(np.allclose(expected_result, out.numpy()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Test_API_Matmul(unittest.TestCase):
|
|
|
|
|
def test_dygraph_without_out(self):
|
|
|
|
|
device = fluid.CPUPlace()
|
|
|
|
|
with fluid.dygraph.guard(device):
|
|
|
|
|
input_array1 = np.random.rand(3, 4).astype("float64")
|
|
|
|
|
input_array2 = np.random.rand(4, 3).astype("float64")
|
|
|
|
|
data1 = fluid.dygraph.to_variable(input_array1)
|
|
|
|
|
data2 = fluid.dygraph.to_variable(input_array2)
|
|
|
|
|
out = paddle.matmul(data1, data2)
|
|
|
|
|
expected_result = np.matmul(input_array1, input_array2)
|
|
|
|
|
self.assertTrue(np.allclose(expected_result, out.numpy()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class API_TestMmError(unittest.TestCase):
|
|
|
|
|
def test_errors(self):
|
|
|
|
|