@ -35,20 +35,34 @@ def while_loop_dyfunc(x):
return i
def for_loop_dyfunc ( max_len ) :
for i in range ( max_len ) :
ret = fluid . layers . zeros ( shape = [ 1 ] , dtype = ' float32 ' )
fluid . layers . increment ( ret , value = 2.0 , in_place = True )
return ret
class TestNameVisitor ( unittest . TestCase ) :
def setUp ( self ) :
self . loop_funcs = [ while_loop_dyfunc , for_loop_dyfunc ]
self . loop_var_names = [ set ( [ " i " , " x " ] ) , set ( [ " i " , " ret " , " max_len " ] ) ]
self . create_var_names = [ set ( ) , set ( [ " ret " ] ) ]
def test_loop_vars ( self ) :
test_func = inspect . getsource ( while_loop_dyfunc )
for i in range ( len ( self . loop_funcs ) ) :
func = self . loop_funcs [ i ]
test_func = inspect . getsource ( func )
gast_root = gast . parse ( test_func )
name_visitor = NameVisitor ( gast_root )
for node in gast . walk ( gast_root ) :
if isinstance ( node , gast . While ) :
if isinstance ( node , ( gast . While , gast . For ) ) :
loop_var_names , create_var_names = name_visitor . get_loop_var_names (
node )
self . assertEqual ( loop_var_names , set ( [ " i " , " x " ] ) )
self . assertEqual ( create_var_names , set ( ) )
self . assertEqual ( loop_var_names , self . loop_var_names [ i ] )
self . assertEqual ( create_var_names , self . create_var_names [ i ] )
class TestTransformWhile ( unittest . TestCase ) :
class TestTransformWhile Loop ( unittest . TestCase ) :
def setUp ( self ) :
self . place = fluid . CUDAPlace ( 0 ) if fluid . is_compiled_with_cuda (
) else fluid . CPUPlace ( )
@ -83,5 +97,35 @@ class TestTransformWhile(unittest.TestCase):
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformForLoop ( unittest . TestCase ) :
def setUp ( self ) :
self . place = fluid . CUDAPlace ( 0 ) if fluid . is_compiled_with_cuda (
) else fluid . CPUPlace ( )
self . len = 100
def _run_static ( self ) :
main_program = fluid . Program ( )
with fluid . program_guard ( main_program ) :
static_func = dygraph_to_static_graph ( for_loop_dyfunc )
out = static_func ( self . len )
exe = fluid . Executor ( self . place )
ret = exe . run ( main_program , fetch_list = out )
return ret
def _run_dygraph ( self ) :
with fluid . dygraph . guard ( self . place ) :
ret = for_loop_dyfunc ( self . len )
return ret . numpy ( )
def test_ast_to_func ( self ) :
static_numpy = self . _run_static ( )
self . assertTrue (
np . allclose (
np . full (
shape = ( 1 ) , fill_value = 2 , dtype = np . int32 ) , static_numpy ) )
self . _run_dygraph ( )
self . assertTrue ( np . allclose ( self . _run_dygraph ( ) , self . _run_static ( ) ) )
if __name__ == ' __main__ ' :
unittest . main ( )