|
|
|
@ -48,18 +48,37 @@ class TestSetValueBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetValueApi(TestSetValueBase):
|
|
|
|
|
def test_api(self):
|
|
|
|
|
def _run_static(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
with paddle.static.program_guard(self.program):
|
|
|
|
|
x = paddle.ones(shape=self.shape, dtype=self.dtype)
|
|
|
|
|
self._call_setitem(x)
|
|
|
|
|
|
|
|
|
|
exe = paddle.static.Executor(paddle.CPUPlace())
|
|
|
|
|
out = exe.run(self.program, fetch_list=[x])
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def _run_dynamic(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
x = paddle.ones(shape=self.shape, dtype=self.dtype)
|
|
|
|
|
self._call_setitem(x)
|
|
|
|
|
out = x.numpy()
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def test_api(self):
|
|
|
|
|
static_out = self._run_static()
|
|
|
|
|
dynamic_out = self._run_dynamic()
|
|
|
|
|
self._get_answer()
|
|
|
|
|
|
|
|
|
|
error_msg = "\nIn {} mode: \nExpected res = \n{}, \n\nbut received : \n{}"
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
(self.data == out).all(),
|
|
|
|
|
msg="\nExpected res = \n{}, \n\nbut received : \n{}".format(
|
|
|
|
|
self.data, out))
|
|
|
|
|
(self.data == static_out).all(),
|
|
|
|
|
msg=error_msg.format("static", self.data, static_out))
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
(self.data == dynamic_out).all(),
|
|
|
|
|
msg=error_msg.format("dynamic", self.data, dynamic_out))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Test different type of item: int, Python slice, Paddle Tensor
|
|
|
|
@ -748,6 +767,7 @@ class TestError(TestSetValueBase):
|
|
|
|
|
exe.run(program)
|
|
|
|
|
|
|
|
|
|
def test_error(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
with paddle.static.program_guard(self.program):
|
|
|
|
|
self._value_type_error()
|
|
|
|
|
self._dtype_error()
|
|
|
|
|