@ -15,12 +15,14 @@
from __future__ import print_function
import unittest
from paddle . fluid . framework import default_main_program , Program , convert_np_dtype_to_dtype_ , in_dygraph_mode
import numpy as np
import six
import paddle
import paddle . fluid as fluid
import paddle . fluid . layers as layers
import paddle . fluid . core as core
import numpy as np
import paddle . fluid . layers as layers
from paddle . fluid . framework import default_main_program , Program , convert_np_dtype_to_dtype_ , in_dygraph_mode
class TestVarBase ( unittest . TestCase ) :
@ -403,5 +405,52 @@ class TestVarBase(unittest.TestCase):
self . assertListEqual ( list ( var_base . shape ) , list ( static_var . shape ) )
class TestVarBaseSetitem ( unittest . TestCase ) :
def setUp ( self ) :
paddle . disable_static ( )
self . tensor_x = paddle . to_tensor ( np . ones ( ( 4 , 2 , 3 ) ) . astype ( np . float32 ) )
self . np_value = np . random . random ( ( 2 , 3 ) ) . astype ( np . float32 )
self . tensor_value = paddle . to_tensor ( self . np_value )
def _test ( self , value ) :
paddle . disable_static ( )
id_origin = id ( self . tensor_x )
self . tensor_x [ 0 ] = value
if isinstance ( value , ( six . integer_types , float ) ) :
result = np . zeros ( ( 2 , 3 ) ) . astype ( np . float32 ) + value
else :
result = self . np_value
self . assertTrue ( np . array_equal ( self . tensor_x [ 0 ] . numpy ( ) , result ) )
self . assertEqual ( id_origin , id ( self . tensor_x ) )
self . tensor_x [ 1 : 2 ] = value
self . assertTrue ( np . array_equal ( self . tensor_x [ 1 ] . numpy ( ) , result ) )
self . assertEqual ( id_origin , id ( self . tensor_x ) )
self . tensor_x [ . . . ] = value
self . assertTrue ( np . array_equal ( self . tensor_x [ 3 ] . numpy ( ) , result ) )
self . assertEqual ( id_origin , id ( self . tensor_x ) )
def test_value_tensor ( self ) :
paddle . disable_static ( )
self . _test ( self . tensor_value )
def test_value_numpy ( self ) :
paddle . disable_static ( )
self . _test ( self . np_value )
def test_value_int ( self ) :
paddle . disable_static ( )
self . _test ( 10 )
def test_value_float ( self ) :
paddle . disable_static ( )
self . _test ( 3.3 )
if __name__ == ' __main__ ' :
unittest . main ( )