Support "while" in Dygraph to Static (#22841)

Add basic support for while in translating dygraph to static

1. Analysis the variable liveness in class NameVisitor
2. Replace while key word using while_loop API
revert-22710-feature/integrated_ps_api
Huihuang Zheng 5 years ago committed by GitHub
parent b6717faf80
commit aca3f5311d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,10 +20,18 @@ from .ast_transformer import *
from . import static_analysis
from .static_analysis import *
from . import loop_transformer
from .loop_transformer import *
from . import variable_trans_func
from .variable_trans_func import *
from . import cache_program
from .cache_program import *
__all__ = []
__all__ += ast_transformer.__all__
__all__ += loop_transformer.__all__
__all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += cache_program.__all__

@ -13,17 +13,21 @@
# limitations under the License.
from __future__ import print_function
from .utils import *
import gast
import textwrap
import inspect
import astor
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
import gast
import textwrap
import inspect
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from .ast_utils import is_control_flow_if, create_cond_node, transform_if_else, ast_to_func
from .static_analysis import AstNodeWrapper, StaticAnalysisVisitor
from .utils import *
__all__ = ['DygraphToStaticAst', 'convert_to_static']
@ -124,17 +128,19 @@ class DygraphToStaticAst(gast.NodeTransformer):
self.transfer_from_node_type(self.static_analysis_root)
return self.static_analysis_root
def transfer_from_node_type(self, node):
def transfer_from_node_type(self, node_wrapper):
# Generic transformation
self.visit(node.node)
self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph
basic_api_trans = BasicApiTransformer(node)
basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans.ast_visit()
self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id()
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node).ast_visit()
IfElseTransformer(node_wrapper).ast_visit()
LoopTransformer(node_wrapper).transform()
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:

@ -14,8 +14,8 @@
from __future__ import print_function
import astor
import ast
import astor
import gast
import six
import copy

@ -0,0 +1,46 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
from paddle.fluid.layers import fill_constant
__all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node']
def to_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format(
name, name)
return gast.parse(func_code)
def create_static_variable_gast_node(name):
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format(
name, name)
return gast.parse(func_code)
def to_static_variable(x):
'''
Translate a Python variable to PaddlePaddle static graph variable
'''
if isinstance(x, bool):
return fill_constant(shape=[1], dtype='bool', value=x)
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
if isinstance(x, float):
return fill_constant(shape=[1], dtype='float64', value=x)
return x

@ -0,0 +1,80 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import inspect
import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
#from paddle.fluid.dygraph.dygraph_to_static import NameVistor
SEED = 2020
np.random.seed(SEED)
def while_loop_dyfunc(x):
i = fluid.dygraph.to_variable(x)
while x < 10:
i = i + x
x = x + 1
return i
class TestNameVisitor(unittest.TestCase):
def test_loop_vars(self):
#TODO
pass
class TestTransformWhile(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.x = np.zeros(shape=(1), dtype=np.int32)
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_var = fluid.layers.assign(self.x)
static_func = dygraph_to_static_graph(while_loop_dyfunc)
out = static_func(x_var)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
return ret
def _run_dygraph(self):
with fluid.dygraph.guard(self.place):
ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x))
return ret.numpy()
def test_ast_to_func(self):
static_numpy = self._run_static()
self.assertTrue(
np.allclose(
np.full(
shape=(1), fill_value=45, dtype=np.int32), static_numpy))
# Enable next lines after Paddle dygraph supports while x < 10
#
# self._run_dygraph()
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save