|
|
|
@ -52,7 +52,6 @@ class TestSetValueApi(TestSetValueBase):
|
|
|
|
|
|
|
|
|
|
exe = paddle.static.Executor(paddle.CPUPlace())
|
|
|
|
|
out = exe.run(self.program, fetch_list=[x])
|
|
|
|
|
|
|
|
|
|
self._get_answer()
|
|
|
|
|
self.assertTrue(
|
|
|
|
|
(self.data == out).all(),
|
|
|
|
@ -60,7 +59,7 @@ class TestSetValueApi(TestSetValueBase):
|
|
|
|
|
self.data, out))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Test different type of item: int, python slice
|
|
|
|
|
# 1. Test different type of item: int, python slice, Ellipsis
|
|
|
|
|
class TestSetValueItemInt(TestSetValueApi):
|
|
|
|
|
def _call_setitem(self, x):
|
|
|
|
|
x[0] = self.value
|
|
|
|
@ -101,6 +100,38 @@ class TestSetValueItemSlice4(TestSetValueApi):
|
|
|
|
|
self.data[0:, 1:2, :] = self.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetValueItemEllipsis1(TestSetValueApi):
|
|
|
|
|
def _call_setitem(self, x):
|
|
|
|
|
x[0:, ..., 1:] = self.value
|
|
|
|
|
|
|
|
|
|
def _get_answer(self):
|
|
|
|
|
self.data[0:, ..., 1:] = self.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetValueItemEllipsis2(TestSetValueApi):
|
|
|
|
|
def _call_setitem(self, x):
|
|
|
|
|
x[0:, ...] = self.value
|
|
|
|
|
|
|
|
|
|
def _get_answer(self):
|
|
|
|
|
self.data[0:, ...] = self.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetValueItemEllipsis3(TestSetValueApi):
|
|
|
|
|
def _call_setitem(self, x):
|
|
|
|
|
x[..., 1:] = self.value
|
|
|
|
|
|
|
|
|
|
def _get_answer(self):
|
|
|
|
|
self.data[..., 1:] = self.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSetValueItemEllipsis4(TestSetValueApi):
|
|
|
|
|
def _call_setitem(self, x):
|
|
|
|
|
x[...] = self.value
|
|
|
|
|
|
|
|
|
|
def _get_answer(self):
|
|
|
|
|
self.data[...] = self.value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Test different type of value: int, float, numpy.ndarray, Tensor
|
|
|
|
|
# 2.1 value is int32, int64, float32, float64, bool
|
|
|
|
|
|
|
|
|
@ -499,6 +530,12 @@ class TestError(TestSetValueBase):
|
|
|
|
|
x = paddle.ones(shape=self.shape, dtype=self.dtype)
|
|
|
|
|
x[0:1:2] = self.value
|
|
|
|
|
|
|
|
|
|
def _ellipsis_error(self):
|
|
|
|
|
with self.assertRaisesRegexp(
|
|
|
|
|
IndexError, "An index can only have a single ellipsis"):
|
|
|
|
|
x = paddle.ones(shape=self.shape, dtype=self.dtype)
|
|
|
|
|
x[..., ...] = self.value
|
|
|
|
|
|
|
|
|
|
def _broadcast_mismatch(self):
|
|
|
|
|
program = paddle.static.Program()
|
|
|
|
|
with paddle.static.program_guard(program):
|
|
|
|
|