|
|
|
@ -26,6 +26,8 @@ import atexit
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType, StaticAnalysisVisitor
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
|
|
|
|
|
|
|
|
|
|
TRUE_FUNC_PREFIX = 'true_fn'
|
|
|
|
|
FALSE_FUNC_PREFIX = 'false_fn'
|
|
|
|
@ -49,23 +51,36 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
|
|
|
|
|
because reshape_op may be called before this statement.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, node):
|
|
|
|
|
self.node = node
|
|
|
|
|
def __init__(self, static_analysis_visitor):
|
|
|
|
|
self.static_analysis_visitor = static_analysis_visitor
|
|
|
|
|
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
|
|
|
|
|
)
|
|
|
|
|
self.is_control_flow = False
|
|
|
|
|
|
|
|
|
|
def ast_visit(self):
|
|
|
|
|
self.visit(self.node)
|
|
|
|
|
def transform(self, node):
|
|
|
|
|
if self._is_candidate_node(node):
|
|
|
|
|
self.visit(node)
|
|
|
|
|
return self.is_control_flow
|
|
|
|
|
|
|
|
|
|
def visit_BoolOp(self, node):
|
|
|
|
|
for child in node.values:
|
|
|
|
|
if not self._is_candidate_node(child):
|
|
|
|
|
continue
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_Compare(self, node):
|
|
|
|
|
for child in gast.walk(node):
|
|
|
|
|
if isinstance(child, gast.Subscript):
|
|
|
|
|
self._visit_Subscript(child)
|
|
|
|
|
# Ignores child node with `if x` or `if x is None`
|
|
|
|
|
if not self._compare_with_none(node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
for child in gast.walk(node):
|
|
|
|
|
if isinstance(child, gast.Subscript):
|
|
|
|
|
self._visit_Subscript(child)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _visit_Subscript(self, node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
if isinstance(node.value, gast.Call):
|
|
|
|
|
if hasattr(node, 'value') and isinstance(node.value, gast.Call):
|
|
|
|
|
self._visit_Call(node.value)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
@ -73,10 +88,40 @@ class IsControlFlowIfVisitor(gast.NodeTransformer):
|
|
|
|
|
assert isinstance(node, gast.Call)
|
|
|
|
|
if isinstance(node.func, gast.Attribute):
|
|
|
|
|
attr_node = node.func
|
|
|
|
|
self.is_control_flow = (attr_node.attr == 'numpy')
|
|
|
|
|
if attr_node.attr == 'numpy':
|
|
|
|
|
self.is_control_flow = True
|
|
|
|
|
|
|
|
|
|
def visit_Call(self, node):
|
|
|
|
|
if is_paddle_api(node):
|
|
|
|
|
self.is_control_flow = True
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_Name(self, node):
|
|
|
|
|
wrapper_node = self.node_to_wrapper_map.get(node, None)
|
|
|
|
|
if wrapper_node is not None:
|
|
|
|
|
if wrapper_node.node_var_type & {
|
|
|
|
|
NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES
|
|
|
|
|
}:
|
|
|
|
|
self.is_control_flow = True
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _is_candidate_node(self, node):
|
|
|
|
|
return isinstance(node, (gast.Compare, gast.BoolOp))
|
|
|
|
|
|
|
|
|
|
def _compare_with_none(self, node):
|
|
|
|
|
if isinstance(node, gast.Compare):
|
|
|
|
|
for child in [node.left, node.comparators]:
|
|
|
|
|
# node.comparators is a list.
|
|
|
|
|
if isinstance(child, list): child = child[0]
|
|
|
|
|
if (isinstance(child, gast.Constant) and
|
|
|
|
|
child.value is None) or (
|
|
|
|
|
isinstance(child, gast.Name) and
|
|
|
|
|
child.id == 'None'):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_control_flow_if(node):
|
|
|
|
|
def is_control_flow_if(node, static_analysis_visitor=None):
|
|
|
|
|
"""
|
|
|
|
|
Determine whether the node is a plain python `if statement` or
|
|
|
|
|
control flow in Paddle.
|
|
|
|
@ -84,7 +129,9 @@ def is_control_flow_if(node):
|
|
|
|
|
assert isinstance(
|
|
|
|
|
node, gast.AST
|
|
|
|
|
), "Type of input node should be gast.AST, but received %s." % type(node)
|
|
|
|
|
return IsControlFlowIfVisitor(node).ast_visit()
|
|
|
|
|
if static_analysis_visitor is None:
|
|
|
|
|
static_analysis_visitor = StaticAnalysisVisitor(node)
|
|
|
|
|
return IsControlFlowIfVisitor(static_analysis_visitor).transform(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
|
|
|
|
|