|
|
|
@ -27,8 +27,53 @@ from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
|
|
|
|
|
TRUE_FUNC_PRFIX = 'true_fn'
|
|
|
|
|
FALSE_FUNC_PRFIX = 'false_fn'
|
|
|
|
|
TRUE_FUNC_PREFIX = 'true_fn'
|
|
|
|
|
FALSE_FUNC_PREFIX = 'false_fn'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IsControlFlowIfVisitor(gast.NodeTransformer):
|
|
|
|
|
"""
|
|
|
|
|
Judge whether the node.test from Dygraph code dependent on paddle Tensor.
|
|
|
|
|
If does, it should satisfy:
|
|
|
|
|
1. must involve at least one var whose type is Tensor.
|
|
|
|
|
2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1].
|
|
|
|
|
3. involve Tensor.shape[i] and the shape[i] is unknown in compile time.
|
|
|
|
|
The following examples should not be considered as control_flow_if:
|
|
|
|
|
1. `if Tensor_var` or `if Tensor_var is None`
|
|
|
|
|
2. if Tensor.shape[i] is determined with fixed value (not -1 or None)
|
|
|
|
|
|
|
|
|
|
Note: pred in ConditionalBlock require variable, which means all vars should be Tensor
|
|
|
|
|
or transformed into Tensor, like fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).
|
|
|
|
|
|
|
|
|
|
TODO: 1. need to deal with `tensor.shape[i]` which need to eval the data of shape[i],
|
|
|
|
|
because reshape_op may be called before this statement.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, node):
|
|
|
|
|
self.node = node
|
|
|
|
|
self.is_control_flow = False
|
|
|
|
|
|
|
|
|
|
def ast_visit(self):
|
|
|
|
|
self.visit(self.node)
|
|
|
|
|
return self.is_control_flow
|
|
|
|
|
|
|
|
|
|
def visit_Compare(self, node):
|
|
|
|
|
for child in gast.walk(node):
|
|
|
|
|
if isinstance(child, gast.Subscript):
|
|
|
|
|
self._visit_Subscript(child)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _visit_Subscript(self, node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
if isinstance(node.value, gast.Call):
|
|
|
|
|
self._visit_Call(node.value)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _visit_Call(self, node):
|
|
|
|
|
assert isinstance(node, gast.Call)
|
|
|
|
|
if isinstance(node.func, gast.Attribute):
|
|
|
|
|
attr_node = node.func
|
|
|
|
|
self.is_control_flow = (attr_node.attr == 'numpy')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_control_flow_if(node):
|
|
|
|
@ -36,7 +81,10 @@ def is_control_flow_if(node):
|
|
|
|
|
Determine whether the node is a plain python `if statement` or
|
|
|
|
|
control flow in Paddle.
|
|
|
|
|
"""
|
|
|
|
|
return True
|
|
|
|
|
assert isinstance(
|
|
|
|
|
node, gast.AST
|
|
|
|
|
), "Type of input node should be gast.AST, but received %s." % type(node)
|
|
|
|
|
return IsControlFlowIfVisitor(node).ast_visit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_name_ids(nodes, not_name_set=None, node_black_list=None):
|
|
|
|
@ -228,12 +276,12 @@ def transform_if_else(node, root):
|
|
|
|
|
|
|
|
|
|
true_func_node = create_funcDef_node(
|
|
|
|
|
node.body,
|
|
|
|
|
name=unique_name.generate(TRUE_FUNC_PRFIX),
|
|
|
|
|
name=unique_name.generate(TRUE_FUNC_PREFIX),
|
|
|
|
|
input_args=parse_cond_args(if_name_ids, modified_name_ids),
|
|
|
|
|
return_name_ids=return_name_ids)
|
|
|
|
|
false_func_node = create_funcDef_node(
|
|
|
|
|
node.orelse,
|
|
|
|
|
name=unique_name.generate(FALSE_FUNC_PRFIX),
|
|
|
|
|
name=unique_name.generate(FALSE_FUNC_PREFIX),
|
|
|
|
|
input_args=parse_cond_args(else_name_ids, modified_name_ids),
|
|
|
|
|
return_name_ids=return_name_ids)
|
|
|
|
|
|
|
|
|
@ -309,7 +357,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
|
|
|
|
|
f = tempfile.NamedTemporaryFile(
|
|
|
|
|
mode='w', suffix='.py', delete=False, encoding='utf-8')
|
|
|
|
|
|
|
|
|
|
# TODO(Aurelius84): more elegent way to transform ast into callable object
|
|
|
|
|
# TODO(Aurelius84): more elegant way to transform ast into callable object
|
|
|
|
|
import_str = "import paddle\n" \
|
|
|
|
|
"import paddle.fluid as fluid\n" \
|
|
|
|
|
"import paddle.fluid.layers as layers\n"
|
|
|
|
|