|
|
|
@ -47,84 +47,17 @@ class PrintTransformer(gast.NodeTransformer):
|
|
|
|
|
# NOTE: deal with print in PY3
|
|
|
|
|
def visit_Call(self, node):
|
|
|
|
|
if isinstance(node.func, gast.Name) and node.func.id == 'print':
|
|
|
|
|
parent_node = self.node_to_wrapper_map[node].parent.node
|
|
|
|
|
if isinstance(parent_node, gast.Expr):
|
|
|
|
|
# NOTE: why need transform to gast.Assign node
|
|
|
|
|
# only fluid.layers.Print(x) will be pruned when exe.run(use_prune=True)
|
|
|
|
|
print_assign_node = self._create_assign_node(node)
|
|
|
|
|
if print_assign_node is not None:
|
|
|
|
|
return print_assign_node
|
|
|
|
|
else:
|
|
|
|
|
return self._transform_call_node(node)
|
|
|
|
|
convert_print_node = self._create_print_node(node.args)
|
|
|
|
|
return gast.Expr(value=convert_print_node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
# NOTE: deal with print in PY2
|
|
|
|
|
def visit_Print(self, node):
|
|
|
|
|
print_assign_node = self._create_assign_node(node)
|
|
|
|
|
if print_assign_node is not None:
|
|
|
|
|
return print_assign_node
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _transform_call_node(self, node):
|
|
|
|
|
assert isinstance(node, gast.Call), "visit Node is not gast.Call node."
|
|
|
|
|
var_node = self._get_print_var_node(node)
|
|
|
|
|
if var_node is None:
|
|
|
|
|
return node
|
|
|
|
|
if self._need_transform(var_node, node):
|
|
|
|
|
return self._build_print_call_node(var_node)
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def _create_assign_node(self, node):
|
|
|
|
|
var_node = self._get_print_var_node(node)
|
|
|
|
|
if var_node is None:
|
|
|
|
|
return None
|
|
|
|
|
if self._need_transform(var_node, node):
|
|
|
|
|
return gast.Assign(
|
|
|
|
|
targets=[var_node], value=self._build_print_call_node(var_node))
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _build_print_call_node(self, node):
|
|
|
|
|
return gast.Call(
|
|
|
|
|
func=gast.parse('fluid.layers.Print').body[0].value,
|
|
|
|
|
args=[node],
|
|
|
|
|
keywords=[
|
|
|
|
|
gast.keyword(
|
|
|
|
|
arg='summarize',
|
|
|
|
|
value=gast.UnaryOp(
|
|
|
|
|
op=gast.USub(),
|
|
|
|
|
operand=gast.Constant(
|
|
|
|
|
value=1, kind=None))), gast.keyword(
|
|
|
|
|
arg='print_phase',
|
|
|
|
|
value=gast.Constant(
|
|
|
|
|
value='forward', kind=None))
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
def _get_print_var_node(self, node):
|
|
|
|
|
if isinstance(node, gast.Call):
|
|
|
|
|
var_list = node.args
|
|
|
|
|
elif isinstance(node, gast.Print):
|
|
|
|
|
var_list = node.values
|
|
|
|
|
if isinstance(var_list[0], gast.Tuple):
|
|
|
|
|
var_list = var_list[0].elts
|
|
|
|
|
# TODO: support print multiple Var
|
|
|
|
|
if len(var_list) == 1:
|
|
|
|
|
return var_list[0]
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning(
|
|
|
|
|
"ProgramTranslator could not transform printing multiple values like < %s > now and will run it as-is."
|
|
|
|
|
% ast_to_source_code(node).strip())
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def _need_transform(self, var_node, print_node):
|
|
|
|
|
if isinstance(var_node, gast.Name):
|
|
|
|
|
if self.static_analysis_visitor.is_tensor_node(var_node):
|
|
|
|
|
return True
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning(
|
|
|
|
|
"ProgramTranslator could not transform printing value that are not Tensor like < %s > now and will run it as-is."
|
|
|
|
|
% ast_to_source_code(print_node).strip())
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning(
|
|
|
|
|
"ProgramTranslator could not transform < %s > now and will run it as-is."
|
|
|
|
|
% ast_to_source_code(print_node).strip())
|
|
|
|
|
return False
|
|
|
|
|
convert_print_node = self._create_print_node(node.values)
|
|
|
|
|
return gast.Expr(value=convert_print_node)
|
|
|
|
|
|
|
|
|
|
def _create_print_node(self, print_args):
|
|
|
|
|
convert_print_func = gast.parse(
|
|
|
|
|
'fluid.dygraph.dygraph_to_static.convert_operators.convert_print'
|
|
|
|
|
).body[0].value
|
|
|
|
|
return gast.Call(func=convert_print_func, args=print_args, keywords=[])
|
|
|
|
|