From 60da88540f9083d08d6c27b501e35d6be26fa660 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Tue, 21 Jul 2020 11:17:49 +0800 Subject: [PATCH] [Dy2stat] Modify print for dynamic type (#25612) Modify the print in Dy2stat for dynamic type. Unit test is covered in old test_print.py --- .../dygraph_to_static/convert_operators.py | 18 +++- .../dygraph_to_static/print_transformer.py | 87 +++---------------- 2 files changed, 26 insertions(+), 79 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 1291be60c6..02d8754e62 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -14,8 +14,9 @@ from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable -from paddle.fluid.framework import Variable, core -from paddle.fluid.layers import Assert, cast, control_flow, logical_and, logical_not, logical_or, nn +from paddle.fluid.framework import core, Variable +from paddle.fluid.layers import Assert, Print +from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn def convert_while_loop(cond, body, loop_vars): @@ -271,3 +272,16 @@ def convert_assert(cond, message=""): return Assert(cond) else: assert cond, message + + +def convert_print(*args): + """ + A function representing Python ``print`` statement. Note: this is a basic + python function so we haven't handle sep, end, file and flush parameters of + python function. + """ + for var in args: + if isinstance(var, Variable): + var = Print(var) + else: + print(var) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py index e55018d2e7..1b6b64ae1f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py @@ -47,84 +47,17 @@ class PrintTransformer(gast.NodeTransformer): # NOTE: deal with print in PY3 def visit_Call(self, node): if isinstance(node.func, gast.Name) and node.func.id == 'print': - parent_node = self.node_to_wrapper_map[node].parent.node - if isinstance(parent_node, gast.Expr): - # NOTE: why need transform to gast.Assign node - # only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True) - print_assign_node = self._create_assign_node(node) - if print_assign_node is not None: - return print_assign_node - else: - return self._transform_call_node(node) + convert_print_node = self._create_print_node(node.args) + return gast.Expr(value=convert_print_node) return node # NOTE: deal with print in PY2 def visit_Print(self, node): - print_assign_node = self._create_assign_node(node) - if print_assign_node is not None: - return print_assign_node - return node - - def _transform_call_node(self, node): - assert isinstance(node, gast.Call), "visit Node is not gast.Call node." - var_node = self._get_print_var_node(node) - if var_node is None: - return node - if self._need_transform(var_node, node): - return self._build_print_call_node(var_node) - return node - - def _create_assign_node(self, node): - var_node = self._get_print_var_node(node) - if var_node is None: - return None - if self._need_transform(var_node, node): - return gast.Assign( - targets=[var_node], value=self._build_print_call_node(var_node)) - return None - - def _build_print_call_node(self, node): - return gast.Call( - func=gast.parse('fluid.layers.Print').body[0].value, - args=[node], - keywords=[ - gast.keyword( - arg='summarize', - value=gast.UnaryOp( - op=gast.USub(), - operand=gast.Constant( - value=1, kind=None))), gast.keyword( - arg='print_phase', - value=gast.Constant( - value='forward', kind=None)) - ]) - - def _get_print_var_node(self, node): - if isinstance(node, gast.Call): - var_list = node.args - elif isinstance(node, gast.Print): - var_list = node.values - if isinstance(var_list[0], gast.Tuple): - var_list = var_list[0].elts - # TODO: support print multiple Var - if len(var_list) == 1: - return var_list[0] - else: - _logger.warning( - "ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is." - % ast_to_source_code(node).strip()) - return None - - def _need_transform(self, var_node, print_node): - if isinstance(var_node, gast.Name): - if self.static_analysis_visitor.is_tensor_node(var_node): - return True - else: - _logger.warning( - "ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is." - % ast_to_source_code(print_node).strip()) - else: - _logger.warning( - "ProgramTranslator could not transform < %s > now and will run it as-is." - % ast_to_source_code(print_node).strip()) - return False + convert_print_node = self._create_print_node(node.values) + return gast.Expr(value=convert_print_node) + + def _create_print_node(self, print_args): + convert_print_func = gast.parse( + 'fluid.dygraph.dygraph_to_static.convert_operators.convert_print' + ).body[0].value + return gast.Call(func=convert_print_func, args=print_args, keywords=[])