|
|
|
@ -62,7 +62,6 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
self.after_visit(self.root)
|
|
|
|
|
|
|
|
|
|
def visit_If(self, node):
|
|
|
|
|
assert isinstance(node, gast.If)
|
|
|
|
|
if_condition_visitor = IfConditionVisitor(node.test,
|
|
|
|
|
self.static_analysis_visitor)
|
|
|
|
|
need_transform = if_condition_visitor.is_control_flow()
|
|
|
|
@ -88,6 +87,22 @@ class IfElseTransformer(gast.NodeTransformer):
|
|
|
|
|
node = attribute.value
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_IfExp(self, node):
|
|
|
|
|
"""
|
|
|
|
|
Transformation with `true_fn(x) if Tensor > 0 else false_fn(x)`
|
|
|
|
|
"""
|
|
|
|
|
if_condition_visitor = IfConditionVisitor(node.test,
|
|
|
|
|
self.static_analysis_visitor)
|
|
|
|
|
need_transform = if_condition_visitor.is_control_flow()
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
if need_transform:
|
|
|
|
|
pred_node, new_assign_nodes = if_condition_visitor.transform()
|
|
|
|
|
new_node = create_cond_node(None, pred_node, node.body, node.orelse,
|
|
|
|
|
True)
|
|
|
|
|
return new_node
|
|
|
|
|
else:
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def after_visit(self, node):
|
|
|
|
|
"""
|
|
|
|
|
This function will add some postprocessing operations with node.
|
|
|
|
@ -130,7 +145,12 @@ def is_candidate_node(node):
|
|
|
|
|
"""
|
|
|
|
|
Nodes with specified type will be dependent on tensor.
|
|
|
|
|
"""
|
|
|
|
|
return isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp))
|
|
|
|
|
is_compare_node = isinstance(node,
|
|
|
|
|
(gast.Compare, gast.BoolOp, gast.UnaryOp))
|
|
|
|
|
# TODO(Aurelius84): `.numpy()` may be an customized function,
|
|
|
|
|
# and should consider a more elegant way to solve this problem.
|
|
|
|
|
has_numpy_attr = ".numpy()" in ast_to_source_code(node)
|
|
|
|
|
return is_compare_node or has_numpy_attr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compare_with_none(node):
|
|
|
|
@ -223,6 +243,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
|
|
|
|
|
self.is_control_flow_num += 1
|
|
|
|
|
|
|
|
|
|
def visit_Call(self, node):
|
|
|
|
|
self._visit_Call(node)
|
|
|
|
|
if is_paddle_api(node):
|
|
|
|
|
self.is_control_flow_num += 1
|
|
|
|
|
return node
|
|
|
|
@ -238,8 +259,7 @@ class IsControlFlowVisitor(gast.NodeVisitor):
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _is_node_with_tensor(self, node, name_id):
|
|
|
|
|
tensor_types = set(
|
|
|
|
|
[NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES])
|
|
|
|
|
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
|
|
|
|
|
# Look up the node_var_type_map by name_id.
|
|
|
|
|
if self.node_var_type_map:
|
|
|
|
|
if name_id and isinstance(name_id, six.string_types):
|
|
|
|
@ -261,7 +281,9 @@ class IsControlFlowVisitor(gast.NodeVisitor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
def __init__(self, ast_node, compare_nodes_with_tensor=set()):
|
|
|
|
|
def __init__(self, ast_node, compare_nodes_with_tensor=None):
|
|
|
|
|
if compare_nodes_with_tensor is None:
|
|
|
|
|
compare_nodes_with_tensor = set()
|
|
|
|
|
self.ast_root = ast_node
|
|
|
|
|
self._compare_nodes_with_tensor = compare_nodes_with_tensor
|
|
|
|
|
self._new_assign_nodes = []
|
|
|
|
@ -269,6 +291,15 @@ class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
def transform(self):
|
|
|
|
|
return self.visit(self.ast_root)
|
|
|
|
|
|
|
|
|
|
def visit_Call(self, node):
|
|
|
|
|
# self.generic_visit(node)
|
|
|
|
|
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
|
|
|
|
|
if isinstance(node.func, gast.Attribute):
|
|
|
|
|
attribute = node.func
|
|
|
|
|
if attribute.attr == 'numpy':
|
|
|
|
|
node = attribute.value
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_UnaryOp(self, node):
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
if isinstance(node.op, gast.Not):
|
|
|
|
@ -297,6 +328,7 @@ class NodeTestTransformer(gast.NodeTransformer):
|
|
|
|
|
if compare_with_none(
|
|
|
|
|
node) or node not in self._compare_nodes_with_tensor:
|
|
|
|
|
return self._create_bool_node(node)
|
|
|
|
|
self.generic_visit(node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _create_bool_node(self, node):
|
|
|
|
@ -656,46 +688,43 @@ def transform_if_else(node, root):
|
|
|
|
|
return true_func_node, false_func_node, return_name_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_cond_node(return_name_ids, pred, true_func, false_func):
|
|
|
|
|
def create_cond_node(return_name_ids,
|
|
|
|
|
pred,
|
|
|
|
|
true_func,
|
|
|
|
|
false_func,
|
|
|
|
|
is_if_expr=False):
|
|
|
|
|
"""
|
|
|
|
|
Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace
|
|
|
|
|
original `python if/else` statement.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def create_lambda_node(func_or_expr_node, is_if_expr=False):
|
|
|
|
|
body = func_or_expr_node
|
|
|
|
|
if not is_if_expr:
|
|
|
|
|
body = gast.Call(
|
|
|
|
|
func=gast.Name(
|
|
|
|
|
id=func_or_expr_node.name,
|
|
|
|
|
ctx=gast.Load(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None),
|
|
|
|
|
args=[func_or_expr_node.args],
|
|
|
|
|
keywords=[])
|
|
|
|
|
|
|
|
|
|
lambda_node = gast.Lambda(
|
|
|
|
|
args=gast.arguments(
|
|
|
|
|
args=[],
|
|
|
|
|
posonlyargs=[],
|
|
|
|
|
vararg=None,
|
|
|
|
|
kwonlyargs=[],
|
|
|
|
|
kw_defaults=None,
|
|
|
|
|
kwarg=None,
|
|
|
|
|
defaults=[]),
|
|
|
|
|
body=body)
|
|
|
|
|
return lambda_node
|
|
|
|
|
|
|
|
|
|
cond_api = gast.parse('fluid.layers.cond').body[0].value
|
|
|
|
|
true_func_lambda = gast.Lambda(
|
|
|
|
|
args=gast.arguments(
|
|
|
|
|
args=[],
|
|
|
|
|
posonlyargs=[],
|
|
|
|
|
vararg=None,
|
|
|
|
|
kwonlyargs=[],
|
|
|
|
|
kw_defaults=None,
|
|
|
|
|
kwarg=None,
|
|
|
|
|
defaults=[]),
|
|
|
|
|
body=gast.Call(
|
|
|
|
|
func=gast.Name(
|
|
|
|
|
id=true_func.name,
|
|
|
|
|
ctx=gast.Load(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None),
|
|
|
|
|
args=[true_func.args],
|
|
|
|
|
keywords=[]))
|
|
|
|
|
false_func_lambda = gast.Lambda(
|
|
|
|
|
args=gast.arguments(
|
|
|
|
|
args=[],
|
|
|
|
|
posonlyargs=[],
|
|
|
|
|
vararg=None,
|
|
|
|
|
kwonlyargs=[],
|
|
|
|
|
kw_defaults=None,
|
|
|
|
|
kwarg=None,
|
|
|
|
|
defaults=[]),
|
|
|
|
|
body=gast.Call(
|
|
|
|
|
func=gast.Name(
|
|
|
|
|
id=false_func.name,
|
|
|
|
|
ctx=gast.Load(),
|
|
|
|
|
annotation=None,
|
|
|
|
|
type_comment=None),
|
|
|
|
|
args=[false_func.args],
|
|
|
|
|
keywords=[]))
|
|
|
|
|
true_func_lambda = create_lambda_node(true_func, is_if_expr)
|
|
|
|
|
false_func_lambda = create_lambda_node(false_func, is_if_expr)
|
|
|
|
|
cond_layer = gast.Call(
|
|
|
|
|
func=cond_api,
|
|
|
|
|
args=[pred, true_func_lambda, false_func_lambda],
|
|
|
|
|