|
|
|
@ -101,7 +101,7 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
fetch_list=[out1, out2, out3, out4, out5])
|
|
|
|
|
out_ref = np.mean(self.x)
|
|
|
|
|
for out in res:
|
|
|
|
|
self.assertEqual(np.allclose(out, out_ref), True)
|
|
|
|
|
self.assertEqual(np.allclose(out, out_ref, rtol=1e-04), True)
|
|
|
|
|
|
|
|
|
|
def test_api_dygraph(self):
|
|
|
|
|
paddle.disable_static(self.place)
|
|
|
|
@ -114,7 +114,9 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
if len(axis) == 0:
|
|
|
|
|
axis = None
|
|
|
|
|
out_ref = np.mean(x, axis, keepdims=keepdim)
|
|
|
|
|
self.assertEqual(np.allclose(out.numpy(), out_ref), True)
|
|
|
|
|
self.assertEqual(
|
|
|
|
|
np.allclose(
|
|
|
|
|
out.numpy(), out_ref, rtol=1e-04), True)
|
|
|
|
|
|
|
|
|
|
test_case(self.x)
|
|
|
|
|
test_case(self.x, [])
|
|
|
|
|