@ -42,7 +42,7 @@ def func_to_test2(x):
return x
result_var_type2 = { ' m ' : NodeVarType . INT }
result_var_type2 = { ' m ' : { NodeVarType . INT } }
def func_to_test3 ( ) :
@ -59,16 +59,16 @@ def func_to_test3():
result_var_type3 = {
' a ' : NodeVarType . INT ,
' b ' : NodeVarType . FLOAT ,
' c ' : NodeVarType . FLOAT ,
' d ' : NodeVarType . FLOAT ,
' e ' : NodeVarType . BOOLEAN ,
' f ' : NodeVarType . INT ,
' g ' : NodeVarType . STRING ,
' h ' : NodeVarType . NONE ,
' i ' : NodeVarType . BOOLEAN ,
' j ' : NodeVarType . UNKNOWN
' a ' : { NodeVarType . INT } ,
' b ' : { NodeVarType . FLOAT } ,
' c ' : { NodeVarType . FLOAT } ,
' d ' : { NodeVarType . FLOAT } ,
' e ' : { NodeVarType . BOOLEAN } ,
' f ' : { NodeVarType . INT } ,
' g ' : { NodeVarType . STRING } ,
' h ' : { NodeVarType . NONE } ,
' i ' : { NodeVarType . BOOLEAN } ,
' j ' : { NodeVarType . UNKNOWN }
}
@ -81,15 +81,48 @@ def func_to_test4():
result_var_type4 = {
' a ' : NodeVarType . NUMPY_NDARRAY ,
' b ' : NodeVarType . NUMPY_NDARRAY ,
' c ' : NodeVarType . TENSOR ,
' d ' : NodeVarType . TENSOR
' a ' : { NodeVarType . NUMPY_NDARRAY } ,
' b ' : { NodeVarType . NUMPY_NDARRAY } ,
' c ' : { NodeVarType . TENSOR } ,
' d ' : { NodeVarType . TENSOR }
}
test_funcs = [ func_to_test1 , func_to_test2 , func_to_test3 , func_to_test4 ]
def func_to_test5 ( ) :
def inner_int_func ( ) :
return 1
def inner_bool_float_func ( x ) :
a = 1.0
if x > 0 :
return a
return False
def inner_unknown_func ( x ) :
return x
a = inner_int_func ( )
b = inner_bool_float_func ( 3 )
c = inner_unknown_func ( None )
d = paddle . fluid . data ( ' x ' , [ 1 , 2 ] )
result_var_type5 = {
' a ' : { NodeVarType . INT } ,
' b ' : { NodeVarType . FLOAT , NodeVarType . BOOLEAN } ,
' c ' : { NodeVarType . UNKNOWN } ,
' d ' : { NodeVarType . PADDLE_RETURN_TYPES } ,
' inner_int_func ' : { NodeVarType . INT } ,
' inner_bool_float_func ' : { NodeVarType . FLOAT , NodeVarType . BOOLEAN } ,
' inner_unknown_func ' : { NodeVarType . UNKNOWN } ,
}
test_funcs = [
func_to_test1 , func_to_test2 , func_to_test3 , func_to_test4 , func_to_test5
]
result_var_type = [
result_var_type1 , result_var_type2 , result_var_type3 , result_var_type4
result_var_type1 , result_var_type2 , result_var_type3 , result_var_type4 ,
result_var_type5
]
@ -117,7 +150,7 @@ class TestStaticAnalysis(unittest.TestCase):
self . _check_wrapper ( wrapper_root , node_to_wrapper_map )
def test_var_env ( self ) :
for i in range ( 4 ) :
for i in range ( 5 ) :
func = test_funcs [ i ]
var_type = result_var_type [ i ]
test_source_code = inspect . getsource ( func )
@ -125,6 +158,11 @@ class TestStaticAnalysis(unittest.TestCase):
print ( gast . dump ( ast_root ) )
visitor = StaticAnalysisVisitor ( ast_root )
var_env = visitor . get_var_env ( )
# There must be 1 sub scope for the test function
self . assertEqual ( 1 , len ( var_env . cur_scope . sub_scopes ) )
var_env . cur_scope = var_env . cur_scope . sub_scopes [ 0 ]
scope_var_type = var_env . get_scope_var_type ( )
self . assertEqual ( len ( scope_var_type ) , len ( var_type ) )
for name in scope_var_type :