|
|
|
@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
else paddle.CPUPlace()
|
|
|
|
|
|
|
|
|
|
def test_api_static(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
with paddle.static.program_guard(paddle.static.Program()):
|
|
|
|
|
x = paddle.data('X', self.x_shape)
|
|
|
|
|
out1 = paddle.mean(x)
|
|
|
|
@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
for out in res:
|
|
|
|
|
self.assertEqual(np.allclose(out, out_ref), True)
|
|
|
|
|
|
|
|
|
|
def test_api_imperative(self):
|
|
|
|
|
def test_api_dygraph(self):
|
|
|
|
|
paddle.disable_static(self.place)
|
|
|
|
|
|
|
|
|
|
def test_case(x, axis=None, keepdim=False):
|
|
|
|
|
x_tensor = paddle.to_variable(x)
|
|
|
|
|
out = paddle.mean(x_tensor, axis, keepdim)
|
|
|
|
@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
out_ref = np.mean(x, axis, keepdims=keepdim)
|
|
|
|
|
self.assertEqual(np.allclose(out.numpy(), out_ref), True)
|
|
|
|
|
|
|
|
|
|
paddle.disable_static(self.place)
|
|
|
|
|
test_case(self.x)
|
|
|
|
|
test_case(self.x, [])
|
|
|
|
|
test_case(self.x, -1)
|
|
|
|
@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
|
|
|
|
def test_errors(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
with paddle.static.program_guard(paddle.static.Program()):
|
|
|
|
|
x = paddle.data('X', [10, 12], 'int8')
|
|
|
|
|
self.assertRaises(TypeError, paddle.mean, x)
|
|
|
|
|