revert-24895-update_cub
zhupengyang 5 years ago committed by GitHub
parent cd48bdad31
commit 9317e51fa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -86,6 +86,7 @@ class TestMeanAPI(unittest.TestCase):
else paddle.CPUPlace()
def test_api_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_shape)
out1 = paddle.mean(x)
@ -102,7 +103,9 @@ class TestMeanAPI(unittest.TestCase):
for out in res:
self.assertEqual(np.allclose(out, out_ref), True)
def test_api_imperative(self):
def test_api_dygraph(self):
paddle.disable_static(self.place)
def test_case(x, axis=None, keepdim=False):
x_tensor = paddle.to_variable(x)
out = paddle.mean(x_tensor, axis, keepdim)
@ -113,7 +116,6 @@ class TestMeanAPI(unittest.TestCase):
out_ref = np.mean(x, axis, keepdims=keepdim)
self.assertEqual(np.allclose(out.numpy(), out_ref), True)
paddle.disable_static(self.place)
test_case(self.x)
test_case(self.x, [])
test_case(self.x, -1)
@ -125,6 +127,7 @@ class TestMeanAPI(unittest.TestCase):
paddle.enable_static()
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12], 'int8')
self.assertRaises(TypeError, paddle.mean, x)

Loading…
Cancel
Save