[Dy2stat] Add Basic Support for Grammar 'return' (#25176)
This PR added basic support for 'return' grammar in dy2stat. It supports the control flow of 'return'. The basics idea is using a return value variable to store the early return statements and boolean state variables with if-else to skip the statements after the return statements. **This PR is very basic support. There are some corner cases I didn't develop/test**. For example, 'return None', 'return different length of variables', 'return non-tensor and tensor together', 'no return statement'. **These corner cases will be done in my next PRs**. Target date is this week. **Note**: 1. for the unit test, I changed test_program_translator.py because the StaticCode of `dyfunc_with_if_else` will change. To guarantee the correctness of `dyfunc_with_if_else`, I also run it in `TestRecursiveReturn` in test_return.py. 2. I commented the early return code in bert_dygraph_model.py because 'return different length of variables' is unsupported now. I also know that there are some other models used early return and we didn't enable it in the unit test. I will add support for it in next PRs and then re-enable those tests.fix_copy_if_different
parent
1458cc0c68
commit
6f631a27c7
@ -0,0 +1,247 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from paddle.fluid import unique_name
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import ForToWhileTransformer
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
|
||||||
|
|
||||||
|
__all__ = ['ReturnTransformer']
|
||||||
|
|
||||||
|
# Constant for the name of the variable which stores the boolean state that we
|
||||||
|
# should return
|
||||||
|
RETURN_PREFIX = '__return'
|
||||||
|
|
||||||
|
# Constant for the name of the variable which stores the final return value
|
||||||
|
RETURN_VALUE_PREFIX = '__return_value'
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnPreAnalysisVisitor(gast.NodeVisitor):
|
||||||
|
"""
|
||||||
|
Visits gast Tree and pre-analyze the information about 'return'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root_node):
|
||||||
|
self.root = root_node
|
||||||
|
|
||||||
|
# A list to store where the current function is.
|
||||||
|
self.function_def = []
|
||||||
|
|
||||||
|
# Mapping from gast.FunctionDef node to the number of return statements
|
||||||
|
# Python allows define function inside function so we have to handle it
|
||||||
|
self.count_return = {}
|
||||||
|
self.visit(self.root)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
self.function_def.append(node)
|
||||||
|
self.count_return[node] = 0
|
||||||
|
self.generic_visit(node)
|
||||||
|
self.function_def.pop()
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Return(self, node):
|
||||||
|
assert len(
|
||||||
|
self.function_def) > 0, "Found 'return' statement out of function."
|
||||||
|
cur_func = self.function_def[-1]
|
||||||
|
if cur_func in self.count_return:
|
||||||
|
self.count_return[cur_func] += 1
|
||||||
|
else:
|
||||||
|
self.count_return[cur_func] = 1
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def get_func_return_count(self, func_node):
|
||||||
|
return self.count_return[func_node]
|
||||||
|
|
||||||
|
def set_func_return_count(self, func_node, count):
|
||||||
|
self.count_return[func_node] = count
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnTransformer(gast.NodeTransformer):
|
||||||
|
"""
|
||||||
|
Transforms return statements into equivalent python statements containing
|
||||||
|
only one return statement at last. The basics idea is using a return value
|
||||||
|
variable to store the early return statements and boolean states with
|
||||||
|
if-else to skip the statements after the return.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, wrapper_root):
|
||||||
|
self.wrapper_root = wrapper_root
|
||||||
|
self.root = wrapper_root.node
|
||||||
|
self.ancestor_nodes = []
|
||||||
|
|
||||||
|
# The name of the variable which stores the final return value
|
||||||
|
# Mapping from FunctionDef node to string
|
||||||
|
self.return_value_name = {}
|
||||||
|
# The names of the variable which stores the boolean state that skip
|
||||||
|
# statments. Mapping from FunctionDef node to list
|
||||||
|
self.return_name = {}
|
||||||
|
# A list of FunctionDef to store where the current function is.
|
||||||
|
self.function_def = []
|
||||||
|
|
||||||
|
def transform(self):
|
||||||
|
self.visit(self.root)
|
||||||
|
|
||||||
|
def generic_visit(self, node):
|
||||||
|
# Because we change ancestor nodes during visit_Return, not current
|
||||||
|
# node, original generic_visit of NodeTransformer will visit node
|
||||||
|
# which may be deleted. To prevent that node being added into
|
||||||
|
# transformed AST, We self-write a generic_visit and visit
|
||||||
|
for field, value in gast.iter_fields(node):
|
||||||
|
if isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, gast.AST):
|
||||||
|
self.visit(item)
|
||||||
|
elif isinstance(value, gast.AST):
|
||||||
|
self.visit(value)
|
||||||
|
|
||||||
|
def visit(self, node):
|
||||||
|
"""
|
||||||
|
Self-defined visit for appending ancestor
|
||||||
|
"""
|
||||||
|
self.ancestor_nodes.append(node)
|
||||||
|
method = 'visit_' + node.__class__.__name__
|
||||||
|
visitor = getattr(self, method, self.generic_visit)
|
||||||
|
ret = visitor(node)
|
||||||
|
self.ancestor_nodes.pop()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node):
|
||||||
|
self.function_def.append(node)
|
||||||
|
self.return_value_name[node] = None
|
||||||
|
self.return_name[node] = []
|
||||||
|
|
||||||
|
pre_analysis = ReturnPreAnalysisVisitor(node)
|
||||||
|
while pre_analysis.get_func_return_count(node) > 1:
|
||||||
|
self.generic_visit(node)
|
||||||
|
pre_analysis = ReturnPreAnalysisVisitor(node)
|
||||||
|
|
||||||
|
# prepend initialization of final return and append final return statement
|
||||||
|
value_name = self.return_value_name[node]
|
||||||
|
if value_name is not None:
|
||||||
|
node.body.append(
|
||||||
|
gast.Return(value=gast.Name(
|
||||||
|
id=value_name,
|
||||||
|
ctx=gast.Load(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None)))
|
||||||
|
assign_zero_node = create_fill_constant_node(value_name, 0.0)
|
||||||
|
node.body.insert(0, assign_zero_node)
|
||||||
|
# Prepend control flow boolean nodes such as '__return@1 = False'
|
||||||
|
for name in self.return_name[node]:
|
||||||
|
assign_false_node = create_fill_constant_node(name, False)
|
||||||
|
node.body.insert(0, assign_false_node)
|
||||||
|
|
||||||
|
self.function_def.pop()
|
||||||
|
return node
|
||||||
|
|
||||||
|
def visit_Return(self, node):
|
||||||
|
cur_func_node = self.function_def[-1]
|
||||||
|
return_name = unique_name.generate(RETURN_PREFIX)
|
||||||
|
self.return_name[cur_func_node].append(return_name)
|
||||||
|
for ancestor_index in reversed(range(len(self.ancestor_nodes) - 1)):
|
||||||
|
ancestor = self.ancestor_nodes[ancestor_index]
|
||||||
|
cur_node = self.ancestor_nodes[ancestor_index + 1]
|
||||||
|
if hasattr(ancestor,
|
||||||
|
"body") and index_in_list(ancestor.body, cur_node) != -1:
|
||||||
|
if cur_node == node:
|
||||||
|
self._replace_return_in_stmt_list(ancestor.body, cur_node,
|
||||||
|
return_name)
|
||||||
|
self._replace_after_node_to_if_in_stmt_list(
|
||||||
|
ancestor.body, cur_node, return_name)
|
||||||
|
elif hasattr(ancestor, "orelse") and index_in_list(ancestor.orelse,
|
||||||
|
cur_node) != -1:
|
||||||
|
if cur_node == node:
|
||||||
|
self._replace_return_in_stmt_list(ancestor.orelse, cur_node,
|
||||||
|
return_name)
|
||||||
|
self._replace_after_node_to_if_in_stmt_list(
|
||||||
|
ancestor.orelse, cur_node, return_name)
|
||||||
|
|
||||||
|
if isinstance(ancestor, gast.While):
|
||||||
|
cond_var_node = gast.UnaryOp(
|
||||||
|
op=gast.Not(),
|
||||||
|
operand=gast.Name(
|
||||||
|
id=return_name,
|
||||||
|
ctx=gast.Load(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None))
|
||||||
|
ancestor.test = gast.BoolOp(
|
||||||
|
op=gast.And(), values=[ancestor.test, cond_var_node])
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(ancestor, gast.For):
|
||||||
|
cond_var_node = gast.UnaryOp(
|
||||||
|
op=gast.Not(),
|
||||||
|
operand=gast.Name(
|
||||||
|
id=return_name,
|
||||||
|
ctx=gast.Load(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None))
|
||||||
|
parent_node = self.ancestor_nodes[ancestor_index - 1]
|
||||||
|
for_to_while = ForToWhileTransformer(parent_node, ancestor,
|
||||||
|
cond_var_node)
|
||||||
|
new_stmts = for_to_while.transform()
|
||||||
|
while_node = new_stmts[-1]
|
||||||
|
self.ancestor_nodes[ancestor_index] = while_node
|
||||||
|
|
||||||
|
if ancestor == cur_func_node:
|
||||||
|
break
|
||||||
|
# return_node is replaced so we shouldn't return here
|
||||||
|
|
||||||
|
def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name):
|
||||||
|
i = index_in_list(stmt_list, return_node)
|
||||||
|
if i == -1:
|
||||||
|
return False
|
||||||
|
assign_nodes = [create_fill_constant_node(return_name, True)]
|
||||||
|
if return_node.value is not None:
|
||||||
|
cur_func_node = self.function_def[-1]
|
||||||
|
if self.return_value_name[cur_func_node] is None:
|
||||||
|
self.return_value_name[cur_func_node] = unique_name.generate(
|
||||||
|
RETURN_VALUE_PREFIX)
|
||||||
|
assign_nodes.append(
|
||||||
|
gast.Assign(
|
||||||
|
targets=[
|
||||||
|
gast.Name(
|
||||||
|
id=self.return_value_name[cur_func_node],
|
||||||
|
ctx=gast.Store(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None)
|
||||||
|
],
|
||||||
|
value=return_node.value))
|
||||||
|
stmt_list[i:] = assign_nodes
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _replace_after_node_to_if_in_stmt_list(self, stmt_list, node,
|
||||||
|
return_name):
|
||||||
|
i = index_in_list(stmt_list, node)
|
||||||
|
if i < 0 or i >= len(stmt_list):
|
||||||
|
return False
|
||||||
|
if i == len(stmt_list) - 1:
|
||||||
|
# No need to add, we consider this as added successfully
|
||||||
|
return True
|
||||||
|
if_stmt = gast.If(test=gast.UnaryOp(
|
||||||
|
op=gast.Not(),
|
||||||
|
operand=gast.Name(
|
||||||
|
id=return_name,
|
||||||
|
ctx=gast.Store(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None)),
|
||||||
|
body=stmt_list[i + 1:],
|
||||||
|
orelse=[])
|
||||||
|
stmt_list[i + 1:] = [if_stmt]
|
||||||
|
return True
|
||||||
@ -0,0 +1,163 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.dygraph import declarative
|
||||||
|
from paddle.fluid.dygraph import ProgramTranslator
|
||||||
|
|
||||||
|
from ifelse_simple_func import dyfunc_with_if_else
|
||||||
|
|
||||||
|
SEED = 2020
|
||||||
|
np.random.seed(SEED)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_return_base(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_inside_func_base(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
|
||||||
|
def inner_func(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
return inner_func(x)
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_return_if(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
if x < 0:
|
||||||
|
x -= 1
|
||||||
|
return -x
|
||||||
|
x += 3
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_return_if_else(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
if x > 0:
|
||||||
|
x += 10086
|
||||||
|
return x
|
||||||
|
x -= 3 # useless statement to test our code can handle it.
|
||||||
|
else:
|
||||||
|
x += 6666
|
||||||
|
return x
|
||||||
|
x -= 8888 # useless statement to test our code can handle it.
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_return_in_while(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
|
||||||
|
while i < 10:
|
||||||
|
i += 1
|
||||||
|
if i > 5:
|
||||||
|
x += 110
|
||||||
|
return x
|
||||||
|
x += i
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_return_in_for(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
for i in range(10):
|
||||||
|
if i <= 4:
|
||||||
|
x += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
return x + 10086
|
||||||
|
return x - 1
|
||||||
|
|
||||||
|
|
||||||
|
@declarative
|
||||||
|
def test_recursive_return(x):
|
||||||
|
x = fluid.dygraph.to_variable(x)
|
||||||
|
return dyfunc_with_if_else(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnBase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.input = np.ones((1)).astype('int32')
|
||||||
|
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||||
|
) else fluid.CPUPlace()
|
||||||
|
self.init_dygraph_func()
|
||||||
|
self.program_translator = ProgramTranslator()
|
||||||
|
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_return_base
|
||||||
|
|
||||||
|
def run_dygraph_mode(self):
|
||||||
|
self.program_translator.enable(False)
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
res = self.dygraph_func(self.input)
|
||||||
|
return res.numpy()
|
||||||
|
|
||||||
|
def run_static_mode(self):
|
||||||
|
self.program_translator.enable(True)
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
res = self.dygraph_func(self.input)
|
||||||
|
return res.numpy()
|
||||||
|
|
||||||
|
def test_transformed_static_result(self):
|
||||||
|
static_res = self.run_static_mode()
|
||||||
|
dygraph_res = self.run_dygraph_mode()
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(dygraph_res, static_res),
|
||||||
|
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
|
||||||
|
static_res))
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsideFuncBase(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_inside_func_base
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnIf(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_return_if
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnIfElse(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_return_if_else
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnInWhile(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_return_in_while
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnInFor(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.dygraph_func = test_return_in_for
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecursiveReturn(TestReturnBase):
|
||||||
|
def init_dygraph_func(self):
|
||||||
|
self.input = self.input.astype(np.float32)
|
||||||
|
self.dygraph_func = test_recursive_return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in new issue