[dy2static] Add print transformer and unify print format (#24068)
* add print transformer & unify print format, test=develop * remove using of dygraph_to_static_func, test=develop * remove python stdout capture, test=develop * fix compatibility problems for PY2, test=develop * fix detail error, test=develop * fix type analysis bug, test=develop * fix print tuple compatible error in PY2, test=develop * replace get_func to declarative, test=develop * fix detail bug, test=develop * fix some detail problems, test=develop * change visit_call in print transformer, test=developrevert-24314-dev/fix_err_msg
parent
3e962aecc1
commit
9b851ba216
@ -0,0 +1,89 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
import astor
|
||||
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
|
||||
|
||||
|
||||
class PrintTransformer(gast.NodeTransformer):
|
||||
"""
|
||||
This class transforms python print function to fluid.layers.Print.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapper_root):
|
||||
assert isinstance(
|
||||
wrapper_root, AstNodeWrapper
|
||||
), "Input non-AstNodeWrapper node for the initialization of PrintTransformer."
|
||||
self.wrapper_root = wrapper_root
|
||||
self.root = wrapper_root.node
|
||||
|
||||
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
|
||||
self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map(
|
||||
)
|
||||
|
||||
def transform(self):
|
||||
self.visit(self.root)
|
||||
|
||||
# NOTE: deal with print in PY3
|
||||
def visit_Call(self, node):
|
||||
assert isinstance(node, gast.Call)
|
||||
if isinstance(node.func, gast.Name) and node.func.id == 'print':
|
||||
var = self._get_print_var(node)
|
||||
return self._construct_print_node(var)
|
||||
return node
|
||||
|
||||
# NOTE: deal with print in PY2
|
||||
def visit_Print(self, node):
|
||||
var = self._get_print_var(node)
|
||||
print_call_node = self._construct_print_node(var)
|
||||
return gast.Expr(value=print_call_node)
|
||||
|
||||
def _get_print_var(self, node):
|
||||
if isinstance(node, gast.Call):
|
||||
var_list = node.args
|
||||
elif isinstance(node, gast.Print):
|
||||
var_list = node.values
|
||||
if isinstance(var_list[0], gast.Tuple):
|
||||
var_list = var_list[0].elts
|
||||
# TODO: support print multiple Var
|
||||
assert len(var_list) == 1, "Now only support print one Variable."
|
||||
return var_list[0]
|
||||
|
||||
def _construct_print_node(self, node):
|
||||
if isinstance(node, gast.Name):
|
||||
if self._is_tensor_node(node):
|
||||
print_node = gast.Call(
|
||||
func=gast.parse('fluid.layers.Print').body[0].value,
|
||||
args=[node],
|
||||
keywords=[])
|
||||
return print_node
|
||||
else:
|
||||
raise TypeError(
|
||||
"print object type error, only support print Variable now.")
|
||||
else:
|
||||
# TODO: may not only print with format
|
||||
raise NotImplementedError(
|
||||
"cannot transform print with format temporarily.")
|
||||
|
||||
def _is_tensor_node(self, node):
|
||||
tensor_types = {NodeVarType.TENSOR, NodeVarType.PADDLE_RETURN_TYPES}
|
||||
wrapper_node = self.node_to_wrapper_map.get(node, None)
|
||||
if wrapper_node is not None:
|
||||
if wrapper_node.node_var_type & tensor_types:
|
||||
return True
|
||||
return False
|
@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy
|
||||
import unittest
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.dygraph.jit import declarative
|
||||
|
||||
|
||||
# 1. print VarBase
|
||||
@declarative
|
||||
def dyfunc_print_variable(x):
|
||||
"""
|
||||
PY2:
|
||||
Print(dest=None, values=[Name(id='x_v', annotation=None, type_comment=None)], nl=True)],
|
||||
PY3:
|
||||
Expr(
|
||||
value=Call(func=Name(id='print', annotation=None, type_comment=None),
|
||||
args=[Name(id='x_v', annotation=None, type_comment=None)],
|
||||
keywords=[]))
|
||||
"""
|
||||
# NOTE: transform to static code, var name will be changed
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
print(x_v)
|
||||
|
||||
|
||||
# 2. print ndarray
|
||||
@declarative
|
||||
def dyfunc_print_ndarray(x):
|
||||
"""
|
||||
PY2:
|
||||
Print(dest=None, values=[Name(id='x', annotation=None, type_comment=None)
|
||||
PY3:
|
||||
Expr(
|
||||
value=Call(func=Name(id='print', annotation=None, type_comment=None),
|
||||
args=[Name(id='x', annotation=None, type_comment=None)],
|
||||
keywords=[]))
|
||||
"""
|
||||
print(x)
|
||||
|
||||
|
||||
# 3. print VarBase with format
|
||||
@declarative
|
||||
def dyfunc_print_with_format(x):
|
||||
"""
|
||||
PY2:
|
||||
Print(dest=None,
|
||||
values=[
|
||||
Call(
|
||||
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
|
||||
args=[Name(id='x_v', annotation=None, type_comment=None)],
|
||||
keywords=[])],
|
||||
nl=True)
|
||||
PY3:
|
||||
Expr(
|
||||
value=Call(func=Name(id='print', annotation=None, type_comment=None),
|
||||
args=[
|
||||
Call(
|
||||
func=Attribute(value=Constant(value='PrintVariable: {}', kind=None), attr='format'),
|
||||
args=[Name(id='x_v', annotation=None, type_comment=None)],
|
||||
keywords=[])],
|
||||
keywords=[]))
|
||||
"""
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
print("PrintVariable: {}".format(x_v))
|
||||
|
||||
|
||||
# 4. print VarBase with format 2
|
||||
@declarative
|
||||
def dyfunc_print_with_format2(x):
|
||||
"""
|
||||
PY2:
|
||||
Print(dest=None,
|
||||
values=[
|
||||
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
|
||||
op=Mod,
|
||||
right=Name(id='x_v', annotation=None, type_comment=None))],
|
||||
nl=True)
|
||||
PY3:
|
||||
Expr(
|
||||
value=Call(func=Name(id='print', annotation=None, type_comment=None),
|
||||
args=[
|
||||
BinOp(left=Constant(value='PrintVariable: %s', kind=None),
|
||||
op=Mod,
|
||||
right=Name(id='x_v', annotation=None, type_comment=None))],
|
||||
keywords=[]))
|
||||
"""
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
print("PrintVariable: %s" % (x_v))
|
||||
|
||||
|
||||
# 5. print VarBase in control flow1
|
||||
@declarative
|
||||
def dyfunc_print_with_ifelse(x):
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
if len(x_v.shape) > 1:
|
||||
print(x_v)
|
||||
else:
|
||||
print(x_v)
|
||||
|
||||
|
||||
# 6. print mutiple VarBases
|
||||
@declarative
|
||||
def dyfunc_print_multi_vars(x):
|
||||
"""
|
||||
# NOTE: y_v type is error before cur PR in this case
|
||||
Assign(targets=[Name(id='y_v', annotation=None, type_comment=None)],
|
||||
value=BinOp(left=Name(id='x_v', annotation=None, type_comment=None), op=Mult, right=Constant(value=2, kind=None)))
|
||||
"""
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
y_v = x_v * 2
|
||||
print(x_v)
|
||||
print(y_v)
|
||||
|
||||
|
||||
# 7. print continue VarBase
|
||||
@declarative
|
||||
def dyfunc_print_continue_vars(x):
|
||||
"""
|
||||
PY3:
|
||||
Expr(
|
||||
value=Call(func=Name(id='print', annotation=None, type_comment=None),
|
||||
args=[Name(id='x_v', annotation=None, type_comment=None),
|
||||
Name(id='y_v', annotation=None, type_comment=None)],
|
||||
keywords=[]))
|
||||
PY2:
|
||||
Print(dest=None,
|
||||
values=[
|
||||
Tuple(
|
||||
elts=[Name(id='x_v', annotation=None, type_comment=None),
|
||||
Name(id='y_v', annotation=None, type_comment=None)])],
|
||||
nl=True)
|
||||
"""
|
||||
x_v = fluid.dygraph.to_variable(x)
|
||||
y_v = x_v * 2
|
||||
print(x_v, y_v)
|
||||
|
||||
|
||||
class TestPrintBase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.input = numpy.ones(5).astype("int32")
|
||||
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||
) else fluid.CPUPlace()
|
||||
self.set_test_func()
|
||||
|
||||
def set_test_func(self):
|
||||
raise NotImplementedError("Print test should implement set_test_func")
|
||||
|
||||
def get_dygraph_output(self):
|
||||
with fluid.dygraph.guard():
|
||||
self.dygraph_func(self.input)
|
||||
|
||||
def get_static_output(self):
|
||||
with fluid.program_guard(fluid.Program()):
|
||||
# TODO: How to catch C++ stdout to python
|
||||
self.dygraph_func(self.input)
|
||||
|
||||
|
||||
class TestPrintVariable(TestPrintBase):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_variable
|
||||
|
||||
def test_transformed_static_result(self):
|
||||
self.get_dygraph_output()
|
||||
self.get_static_output()
|
||||
|
||||
|
||||
class TestPrintNdArray(TestPrintBase):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_ndarray
|
||||
|
||||
def test_transform_static_error(self):
|
||||
with self.assertRaises(TypeError):
|
||||
self.get_dygraph_output()
|
||||
self.get_static_output()
|
||||
|
||||
|
||||
class TestPrintWithFormat(TestPrintBase):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_with_format
|
||||
|
||||
def test_transform_static_error(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.get_dygraph_output()
|
||||
self.get_static_output()
|
||||
|
||||
|
||||
class TestPrintWithFormat2(TestPrintBase):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_with_format2
|
||||
|
||||
def test_transform_static_error(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.get_dygraph_output()
|
||||
self.get_static_output()
|
||||
|
||||
|
||||
class TestPrintWithIfElse(TestPrintVariable):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_with_ifelse
|
||||
|
||||
|
||||
class TestPrintMultipleVar(TestPrintVariable):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_multi_vars
|
||||
|
||||
|
||||
class TestPrintContinueVar(TestPrintBase):
|
||||
def set_test_func(self):
|
||||
self.dygraph_func = dyfunc_print_continue_vars
|
||||
|
||||
def test_transform_static_error(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
self.get_dygraph_output()
|
||||
self.get_static_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue