|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import gast
|
|
|
|
|
import warnings
|
|
|
|
|
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api
|
|
|
|
|
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
|
|
|
|
|
|
|
|
|
|
__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']
|
|
|
|
|
|
|
|
|
@ -260,20 +260,27 @@ class StaticAnalysisVisitor(object):
|
|
|
|
|
def get_var_env(self):
|
|
|
|
|
return self.var_env
|
|
|
|
|
|
|
|
|
|
def _get_constant_node_type(self, node):
|
|
|
|
|
assert isinstance(node, gast.Constant), \
|
|
|
|
|
"Type of input node should be gast.Constant, but received %s" % type(node)
|
|
|
|
|
# singleton: None, True or False
|
|
|
|
|
if node.value is None:
|
|
|
|
|
return {NodeVarType.NONE}
|
|
|
|
|
if isinstance(node.value, bool):
|
|
|
|
|
return {NodeVarType.BOOLEAN}
|
|
|
|
|
if isinstance(node.value, int):
|
|
|
|
|
return {NodeVarType.INT}
|
|
|
|
|
if isinstance(node.value, float):
|
|
|
|
|
return {NodeVarType.FLOAT}
|
|
|
|
|
if isinstance(node.value, str):
|
|
|
|
|
return {NodeVarType.STRING}
|
|
|
|
|
|
|
|
|
|
return {NodeVarType.UNKNOWN}
|
|
|
|
|
|
|
|
|
|
def _get_node_var_type(self, cur_wrapper):
|
|
|
|
|
node = cur_wrapper.node
|
|
|
|
|
if isinstance(node, gast.Constant):
|
|
|
|
|
# singleton: None, True or False
|
|
|
|
|
if node.value is None:
|
|
|
|
|
return {NodeVarType.NONE}
|
|
|
|
|
if isinstance(node.value, bool):
|
|
|
|
|
return {NodeVarType.BOOLEAN}
|
|
|
|
|
if isinstance(node.value, int):
|
|
|
|
|
return {NodeVarType.INT}
|
|
|
|
|
if isinstance(node.value, float):
|
|
|
|
|
return {NodeVarType.FLOAT}
|
|
|
|
|
if isinstance(node.value, str):
|
|
|
|
|
return {NodeVarType.STRING}
|
|
|
|
|
return self._get_constant_node_type(node)
|
|
|
|
|
|
|
|
|
|
if isinstance(node, gast.BoolOp):
|
|
|
|
|
return {NodeVarType.BOOLEAN}
|
|
|
|
@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object):
|
|
|
|
|
if isinstance(node, gast.Name):
|
|
|
|
|
if node.id == "None":
|
|
|
|
|
return {NodeVarType.NONE}
|
|
|
|
|
if node.id == "True" or node.id == "False":
|
|
|
|
|
if node.id in {"True", "False"}:
|
|
|
|
|
return {NodeVarType.BOOLEAN}
|
|
|
|
|
# If node is child of functionDef.arguments
|
|
|
|
|
parent_node_wrapper = cur_wrapper.parent
|
|
|
|
|
if parent_node_wrapper and isinstance(parent_node_wrapper.node,
|
|
|
|
|
gast.arguments):
|
|
|
|
|
parent_node = parent_node_wrapper.node
|
|
|
|
|
var_type = {NodeVarType.UNKNOWN}
|
|
|
|
|
if parent_node.defaults:
|
|
|
|
|
index = index_in_list(parent_node.args, node)
|
|
|
|
|
args_len = len(parent_node.args)
|
|
|
|
|
if index != -1 and args_len - index <= len(
|
|
|
|
|
parent_node.defaults):
|
|
|
|
|
defaults_node = parent_node.defaults[index - args_len]
|
|
|
|
|
if isinstance(defaults_node, gast.Constant):
|
|
|
|
|
var_type = self._get_constant_node_type(
|
|
|
|
|
defaults_node)
|
|
|
|
|
|
|
|
|
|
# Add node with identified type into cur_env.
|
|
|
|
|
self.var_env.set_var_type(node.id, var_type)
|
|
|
|
|
return var_type
|
|
|
|
|
|
|
|
|
|
return self.var_env.get_var_type(node.id)
|
|
|
|
|
|
|
|
|
|
if isinstance(node, gast.Return):
|
|
|
|
|