|
|
|
@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase):
|
|
|
|
|
class TestVarBaseSetitem(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
|
|
|
|
|
self.np_value = np.random.random((2, 3)).astype(np.float32)
|
|
|
|
|
self.set_dtype()
|
|
|
|
|
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
|
|
|
|
|
self.np_value = np.random.random((2, 3)).astype(self.dtype)
|
|
|
|
|
self.tensor_value = paddle.to_tensor(self.np_value)
|
|
|
|
|
|
|
|
|
|
def set_dtype(self):
|
|
|
|
|
self.dtype = "int32"
|
|
|
|
|
|
|
|
|
|
def _test(self, value):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
self.assertEqual(self.tensor_x.inplace_version, 0)
|
|
|
|
@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase):
|
|
|
|
|
self.assertEqual(self.tensor_x.inplace_version, 1)
|
|
|
|
|
|
|
|
|
|
if isinstance(value, (six.integer_types, float)):
|
|
|
|
|
result = np.zeros((2, 3)).astype(np.float32) + value
|
|
|
|
|
result = np.zeros((2, 3)).astype(self.dtype) + value
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
result = self.np_value
|
|
|
|
@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
self._test(10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestVarBaseSetitemInt64(TestVarBaseSetitem):
|
|
|
|
|
def set_dtype(self):
|
|
|
|
|
self.dtype = "int64"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestVarBaseSetitemFp32(TestVarBaseSetitem):
|
|
|
|
|
def set_dtype(self):
|
|
|
|
|
self.dtype = "float32"
|
|
|
|
|
|
|
|
|
|
def test_value_float(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
self._test(3.3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestVarBaseSetitemFp64(TestVarBaseSetitem):
|
|
|
|
|
def set_dtype(self):
|
|
|
|
|
self.dtype = "float64"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestVarBaseInplaceVersion(unittest.TestCase):
|
|
|
|
|
def test_setitem(self):
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|