[Dy2Stat] Optimize loop cond (#24049)

* Simplify code for gast.If in is_control_flow_to_transform.
* Move IsControlFlowVisitor to file utils. 
* Don't use convert_call for built-in functions in CallTransformer.
* Optimize api is_control_flow_to_transform. 
* Polish the document of IsControlFlowVisitor.
revert-22778-infer_var_type
liym27 5 years ago committed by GitHub
parent aa0330f451
commit 2961a4f07d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,6 +32,15 @@ class CallTransformer(gast.NodeTransformer):
self.wrapper_root = wrapper_root
self.root = wrapper_root.node
def _is_builtin_call(self, node):
    """
    Return True if the `gast.Call` node invokes a Python builtin.

    The call's function expression is evaluated and handed to
    `is_builtin`; any failure (unresolvable name, import problem, ...)
    conservatively yields False.
    """
    assert isinstance(node, gast.Call)
    func_source = ast_to_source_code(node.func).strip()
    check_expr = "is_builtin({})".format(func_source)
    try:
        # Imported lazily; presumably avoids a circular import at module
        # load time — TODO(review): confirm against the module layout.
        from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
        return eval(check_expr)
    except Exception:
        return False
def transform(self):
self.visit(self.root)
@ -39,6 +48,10 @@ class CallTransformer(gast.NodeTransformer):
self.generic_visit(node)
if is_paddle_api(node):
return node
if self._is_builtin_call(node):
return node
func_str = ast_to_source_code(node.func).strip()
new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
func_str)

@ -25,12 +25,15 @@ import gast
import six
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none
from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
@ -142,145 +145,6 @@ class IfElseTransformer(gast.NodeTransformer):
return self.new_func_nodes
def is_candidate_node(node):
    """
    Nodes with specified type will be dependent on tensor.
    """
    if isinstance(node, (gast.Compare, gast.BoolOp, gast.UnaryOp)):
        return True
    # TODO(Aurelius84): `.numpy()` may be an customized function,
    # and should consider a more elegant way to solve this problem.
    return ".numpy()" in ast_to_source_code(node)
def compare_with_none(node):
    """
    Whether the comparator of `gast.Compare` node is `None`.
    """
    if not isinstance(node, gast.Compare):
        return False
    for operand in (node.left, node.comparators):
        # node.comparators is a list; only its first element is inspected.
        if isinstance(operand, list):
            operand = operand[0]
        is_none_constant = isinstance(
            operand, gast.Constant) and operand.value is None
        is_none_name = isinstance(operand, gast.Name) and operand.id == 'None'
        if is_none_constant or is_none_name:
            return True
    return False
class IsControlFlowVisitor(gast.NodeVisitor):
    """
    Judge whether the node.test from Dygraph code dependent on paddle Tensor.
    If does, it should satisfy:
    1. must involve at least one var whose type is Tensor.
    2. the Tensor var should call `.numpy()[]` interface or Tensor.shape is [1].
    3. involve Tensor.shape[i] and the shape[i] is unknown in compile time.
    The following examples should not be considered as control_flow_if:
    1. `if Tensor_var` or `if Tensor_var is None`
    2. if Tensor.shape[i] is determined with fixed value (not -1 or None)

    Note: pred in ConditionalBlock requires a variable, which means all vars
    should be Tensor or transformed into Tensor, like
    fill_constant(shape=[1], dtype='int32', value=Tensor.shape[i]).

    TODO: 1. need to deal with `tensor.shape[i]` which needs to eval the data
    of shape[i], because reshape_op may be called before this statement.
    """

    def __init__(self,
                 ast_node,
                 static_analysis_visitor=None,
                 node_var_type_map=None):
        assert isinstance(
            ast_node, gast.AST
        ), "Type of input node should be gast.AST, but received %s." % type(
            ast_node)
        self.ast_root = ast_node
        # Reuse an existing static analysis when the caller provides one;
        # otherwise analyze `ast_node` from scratch.
        if static_analysis_visitor is None:
            static_analysis_visitor = StaticAnalysisVisitor(ast_node)
        self.static_analysis_visitor = static_analysis_visitor
        # Optional {var_name: NodeVarType} map consulted before the
        # wrapper-node lookup in `_is_node_with_tensor`.
        self.node_var_type_map = node_var_type_map

        # Number of visited sub-nodes found to depend on a paddle Tensor.
        self.is_control_flow_num = 0
        # `gast.Compare` nodes found to depend on a Tensor.
        # (Renamed to fix the original `tenor` -> `tensor` typo; the name
        # is private and all uses are inside this class.)
        self._compare_node_tensor_set = set()

    def transform(self):
        """Visit the root expression; return True if it is Tensor-dependent
        control flow (i.e. at least one Tensor-related sub-node was seen)."""
        node = self.ast_root
        if is_candidate_node(node):
            self.visit(node)
        return self.is_control_flow_num > 0

    def visit_BoolOp(self, node):
        # Only candidate children (Compare/BoolOp/UnaryOp or `.numpy()`
        # users) can contribute to the control-flow decision.
        for child in node.values:
            if is_candidate_node(child):
                self.visit(child)
        return node

    def visit_Compare(self, node):
        # Ignores child node with `if x` or `if x is None`
        # TODO(Aurelius84): `if tensor` will be supported in dygraph
        # and should be considered as is_control_flow.
        pre_control_flow_num = self.is_control_flow_num
        if not compare_with_none(node):
            self.generic_visit(node)
            # `gast.Subscript` is not dispatched by generic_visit's naming
            # scheme here, so walk the subtree for `x.numpy()[i]` patterns.
            for child in gast.walk(node):
                if isinstance(child, gast.Subscript):
                    self._visit_Subscript(child)
        # Record Compare nodes whose visit raised the counter: they are the
        # ones that actually depend on a Tensor.
        if self.is_control_flow_num > pre_control_flow_num:
            self._compare_node_tensor_set.add(node)
        return node

    def _visit_Subscript(self, node):
        self.generic_visit(node)
        # `x.numpy()[i]` subscripts a Call result; inspect that call.
        if hasattr(node, 'value') and isinstance(node.value, gast.Call):
            self._visit_Call(node.value)
        return node

    def _visit_Call(self, node):
        assert isinstance(node, gast.Call)
        if isinstance(node.func, gast.Attribute):
            attr_node = node.func
            # A `.numpy()` call marks the value as Tensor-dependent.
            if attr_node.attr == 'numpy':
                self.is_control_flow_num += 1

    def visit_Call(self, node):
        self._visit_Call(node)
        # Calls into paddle APIs also count as Tensor-dependent.
        if is_paddle_api(node):
            self.is_control_flow_num += 1
        return node

    def visit_Name(self, node):
        if self._is_node_with_tensor(node, node.id):
            self.is_control_flow_num += 1
        return node

    def visit_Constant(self, node):
        if self._is_node_with_tensor(node, node.value):
            self.is_control_flow_num += 1
        return node

    def _is_node_with_tensor(self, node, name_id):
        """Return True if `node` (or the name `name_id`) is statically known
        to be a Tensor or a paddle-API return value."""
        tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
        # Look up the node_var_type_map by name_id.
        if self.node_var_type_map:
            if name_id and isinstance(name_id, six.string_types):
                var_type = self.node_var_type_map.get(name_id, None)
                if var_type and var_type & tensor_types:
                    return True
        # if not found, look up the node_to_wrapper_map by node.
        node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map()
        wrapper_node = node_to_wrapper_map.get(node, None)
        if wrapper_node is not None:
            if wrapper_node.node_var_type & tensor_types:
                return True
        return False

    def get_compare_nodes_with_tensor(self):
        """Compare nodes recorded by `visit_Compare` as Tensor-dependent."""
        return self._compare_node_tensor_set
class NodeTestTransformer(gast.NodeTransformer):
def __init__(self, ast_node, compare_nodes_with_tensor=None):
if compare_nodes_with_tensor is None:

@ -55,19 +55,22 @@ class ListTransformer(gast.NodeTransformer):
def visit_If(self, node):
    """Transform `list.append` inside a Tensor-dependent `if` statement.

    NOTE(review): the source contained both the old call
    `is_control_flow_to_transform(node, self.scope_var_type_dict)` and the
    new three-argument call — a leftover merge artifact producing a
    duplicated condition; only the variant taking
    `static_analysis_visitor` is kept.
    """
    self.generic_visit(node)
    if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                    self.scope_var_type_dict):
        self._transform_list_append_in_control_flow(node)
    return node
def visit_While(self, node):
    """Transform `list.append` inside a Tensor-dependent `while` loop.

    NOTE(review): the source contained both the old two-argument and the
    new three-argument call to `is_control_flow_to_transform` (merge
    artifact); only the variant taking `static_analysis_visitor` is kept.
    """
    self.generic_visit(node)
    if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                    self.scope_var_type_dict):
        self._transform_list_append_in_control_flow(node)
    return node
def visit_For(self, node):
    """Transform `list.append` inside a Tensor-dependent `for` loop.

    NOTE(review): the source contained both the old two-argument and the
    new three-argument call to `is_control_flow_to_transform` (merge
    artifact); only the variant taking `static_analysis_visitor` is kept.
    """
    self.generic_visit(node)
    if is_control_flow_to_transform(node, self.static_analysis_visitor,
                                    self.scope_var_type_dict):
        self._transform_list_append_in_control_flow(node)
    return node

@ -26,6 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
@ -150,8 +151,9 @@ class NameVisitor(gast.NodeVisitor):
self.visit(root_node)
def is_control_flow_loop(self, node):
    """Return True if `node` is a loop whose condition depends on a Tensor.

    The unconditional `return True` (with its "make a better condition"
    TODO) that preceded the real check made the call below unreachable
    dead code; it has been removed so the static analysis result is
    actually used.
    """
    need_transform = is_control_flow_to_transform(
        node, self.static_analysis_visitor)
    return need_transform
def get_loop_var_names(self, node):
assert isinstance(

@ -15,15 +15,8 @@
from __future__ import print_function
import gast
import astor
import copy
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor

File diff suppressed because it is too large Load Diff

@ -19,9 +19,9 @@ import textwrap
import gast
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import get_name_ids
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfConditionVisitor
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IsControlFlowVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor
class TestGetNameIds(unittest.TestCase):

@ -47,6 +47,10 @@ def test_list_in_if(x):
def test_list_in_for_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
# Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32"
) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
a = []
for i in range(iter_num):
a.append(x)
@ -56,6 +60,10 @@ def test_list_in_for_loop(x, iter_num):
def test_list_in_for_loop_with_concat(x, iter_num):
x = fluid.dygraph.to_variable(x)
a = []
# Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32"
) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
for i in range(iter_num):
a.append(x)
a = fluid.layers.concat(a, axis=0)

@ -29,6 +29,9 @@ np.random.seed(SEED)
def while_loop_dyfunc(x):
i = fluid.dygraph.to_variable(x)
# Use `to_variable` so that static analysis can analyze the type of X is Tensor
x = fluid.dygraph.to_variable(
x) # TODO(liym27): Delete it if the type of parameter x can be resolved
while x < 10:
i = i + x
x = x + 1
@ -37,6 +40,9 @@ def while_loop_dyfunc(x):
def while_loop_dyfun_with_conflict_var(x):
i = fluid.dygraph.to_variable(x)
# Use `to_variable` so that static analysis can analyze the type of X is Tensor
x = fluid.dygraph.to_variable(
x) # TODO(liym27): Delete it if the type of parameter x can be resolved
def relu(y):
# 'y' is not visible outside the scope.
@ -56,6 +62,9 @@ def while_loop_dyfunc_with_none(x):
i = fluid.dygraph.to_variable(x)\
if x is not None \
else fluid.dygraph.to_variable(x+1)
# Use `to_variable` so that static analysis can analyze the type of X is Tensor
x = fluid.dygraph.to_variable(
x) # TODO(liym27): Delete it if the type of parameter x can be resolved
flag = 1
while x < 10:
i = i + x if flag is not None else x + i
@ -72,6 +81,10 @@ def for_loop_dyfunc(max_len):
def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x)
# Use `to_variable` so that static analysis can analyze the type of X is Tensor
x = fluid.dygraph.to_variable(
x) # TODO(liym27): Delete it if the type of parameter x can be resolved
while (x >= 0 and x < 10) or x <= -1 or x < -3 or (x < -7 or x < -5):
i = i + x
x = x + 1
@ -102,6 +115,11 @@ def for_loop_class_var(max_len):
self.c = 5
foo = Foo()
# Use `to_variable` so that static analysis can analyze the type of X is Tensor
# TODO(liym27): Delete it if the type of parameter x can be resolved
max_len = fluid.layers.fill_constant(
shape=[1], value=max_len, dtype="int32")
for i in range(max_len):
foo.b = fluid.layers.zeros(shape=[1], dtype='float32')
foo.c = foo.b + foo.a

@ -69,6 +69,11 @@ def test_slice_in_while_loop(x, iter_num):
def test_slice_in_for_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
a = []
# Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
iter_num = fluid.layers.fill_constant(
shape=[1], value=iter_num, dtype="int32"
) # TODO(liym27): Delete it if the type of parameter iter_num can be resolved
for i in range(iter_num):
a.append(x)

Loading…
Cancel
Save