Support "while" in Dygraph to Static (#22841)
Add basic support for while in translating dygraph to static 1. Analysis the variable liveness in class NameVisitor 2. Replace while key word using while_loop APIrevert-22710-feature/integrated_ps_api
parent
b6717faf80
commit
aca3f5311d
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
|
||||
from paddle.fluid.layers import fill_constant
|
||||
|
||||
__all__ = ['to_static_variable_gast_node', 'create_static_variable_gast_node']
|
||||
|
||||
|
||||
def to_static_variable_gast_node(name):
|
||||
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})".format(
|
||||
name, name)
|
||||
return gast.parse(func_code)
|
||||
|
||||
|
||||
def create_static_variable_gast_node(name):
|
||||
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format(
|
||||
name, name)
|
||||
return gast.parse(func_code)
|
||||
|
||||
|
||||
def to_static_variable(x):
|
||||
'''
|
||||
Translate a Python variable to PaddlePaddle static graph variable
|
||||
'''
|
||||
if isinstance(x, bool):
|
||||
return fill_constant(shape=[1], dtype='bool', value=x)
|
||||
if isinstance(x, int):
|
||||
return fill_constant(shape=[1], dtype='int64', value=x)
|
||||
if isinstance(x, float):
|
||||
return fill_constant(shape=[1], dtype='float64', value=x)
|
||||
return x
|
@ -0,0 +1,80 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import gast
|
||||
import inspect
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import unittest
|
||||
|
||||
from paddle.fluid.dygraph.jit import dygraph_to_static_graph
|
||||
#from paddle.fluid.dygraph.dygraph_to_static import NameVistor
|
||||
|
||||
SEED = 2020
|
||||
np.random.seed(SEED)
|
||||
|
||||
|
||||
def while_loop_dyfunc(x):
|
||||
i = fluid.dygraph.to_variable(x)
|
||||
while x < 10:
|
||||
i = i + x
|
||||
x = x + 1
|
||||
return i
|
||||
|
||||
|
||||
class TestNameVisitor(unittest.TestCase):
|
||||
def test_loop_vars(self):
|
||||
#TODO
|
||||
pass
|
||||
|
||||
|
||||
class TestTransformWhile(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||
) else fluid.CPUPlace()
|
||||
self.x = np.zeros(shape=(1), dtype=np.int32)
|
||||
|
||||
def _run_static(self):
|
||||
main_program = fluid.Program()
|
||||
with fluid.program_guard(main_program):
|
||||
x_var = fluid.layers.assign(self.x)
|
||||
static_func = dygraph_to_static_graph(while_loop_dyfunc)
|
||||
|
||||
out = static_func(x_var)
|
||||
exe = fluid.Executor(self.place)
|
||||
ret = exe.run(main_program, fetch_list=out)
|
||||
return ret
|
||||
|
||||
def _run_dygraph(self):
|
||||
with fluid.dygraph.guard(self.place):
|
||||
ret = while_loop_dyfunc(fluid.dygraph.to_variable(self.x))
|
||||
return ret.numpy()
|
||||
|
||||
def test_ast_to_func(self):
|
||||
static_numpy = self._run_static()
|
||||
self.assertTrue(
|
||||
np.allclose(
|
||||
np.full(
|
||||
shape=(1), fill_value=45, dtype=np.int32), static_numpy))
|
||||
|
||||
# Enable next lines after Paddle dygraph supports while x < 10
|
||||
#
|
||||
# self._run_dygraph()
|
||||
# self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue