Support if/else in dygraph_to_static (#22540)
* support nested if/else * support to derivate returns the parameter list automatically * polish tranform function of slice * fix modify x.numpy()[i] slice function * support to transform ast.node into callable function * fix get_name_ids bug and add more unittest test=develop * fix requirements.txt test=develop * remove useless import statement test=develop * Fixed version compatibility issues in param of function test=develop * use decorater to test ast_to_func test=develop * add textwrap.dedent for source_code test=develop * polish code comment * fix compatibility with python2 and python3 test=develop * fix gast version error test=develop * fix gast repo test=develop * polish transfer_from_node_type code test=develop * add nested_if_else unittest test=develop * split IfElseTransformer test=develop * specify gast version test=develop * fix ast_to_func root type test=developrevert-22710-feature/integrated_ps_api
parent
7a4c29e0e0
commit
08b09f6447
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,165 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import textwrap
|
||||||
|
import gast
|
||||||
|
import inspect
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid.dygraph.dygraph_to_static.ast_utils import get_name_ids, ast_to_func
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetNameIds(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test for parsing the ast.Name list from the ast.Nodes
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.source = """
|
||||||
|
def test_fn(x):
|
||||||
|
return x+1
|
||||||
|
"""
|
||||||
|
self.all_name_ids = {'x': [gast.Param()]}
|
||||||
|
|
||||||
|
def test_get_name_ids(self):
|
||||||
|
source = textwrap.dedent(self.source)
|
||||||
|
root = gast.parse(source)
|
||||||
|
all_name_ids = get_name_ids([root])
|
||||||
|
self.assertDictEqual(
|
||||||
|
self.transfer_dict(self.all_name_ids),
|
||||||
|
self.transfer_dict(all_name_ids))
|
||||||
|
|
||||||
|
def transfer_dict(self, name_ids_dict):
|
||||||
|
new_dict = {}
|
||||||
|
for name, ctxs in name_ids_dict.items():
|
||||||
|
new_dict[name] = [type(ctx) for ctx in ctxs]
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetNameIds2(TestGetNameIds):
|
||||||
|
def setUp(self):
|
||||||
|
self.source = """
|
||||||
|
def test_fn(x, y):
|
||||||
|
a = 1
|
||||||
|
x = y + a
|
||||||
|
if x > y:
|
||||||
|
z = x * x
|
||||||
|
z = z + a
|
||||||
|
else:
|
||||||
|
z = y * y
|
||||||
|
return z
|
||||||
|
"""
|
||||||
|
self.all_name_ids = {
|
||||||
|
'x': [
|
||||||
|
gast.Param(), gast.Store(), gast.Load(), gast.Load(),
|
||||||
|
gast.Load()
|
||||||
|
],
|
||||||
|
'a': [gast.Store(), gast.Load(), gast.Load()],
|
||||||
|
'y':
|
||||||
|
[gast.Param(), gast.Load(), gast.Load(), gast.Load(), gast.Load()],
|
||||||
|
'z': [gast.Store(), gast.Load(), gast.Store(), gast.Store()]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetNameIds3(TestGetNameIds):
|
||||||
|
def setUp(self):
|
||||||
|
self.source = """
|
||||||
|
def test_fn(x, y):
|
||||||
|
z = 1
|
||||||
|
if x > y:
|
||||||
|
z = x * x
|
||||||
|
z = z + y
|
||||||
|
return z
|
||||||
|
"""
|
||||||
|
self.all_name_ids = {
|
||||||
|
'x': [gast.Param(), gast.Load(), gast.Load(), gast.Load()],
|
||||||
|
'y': [gast.Param(), gast.Load(), gast.Load()],
|
||||||
|
'z': [gast.Store(), gast.Store(), gast.Load(), gast.Store()]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def dyfunc_with_if_else(x_v):
|
||||||
|
if fluid.layers.mean(x_v).numpy()[0] > 5:
|
||||||
|
x_v = x_v - 1
|
||||||
|
else:
|
||||||
|
x_v = x_v + 1
|
||||||
|
return x_v
|
||||||
|
|
||||||
|
|
||||||
|
def dyfunc_with_if_else2(x):
|
||||||
|
i, j = 0, 0
|
||||||
|
if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[i][j]:
|
||||||
|
y = fluid.layers.relu(x)
|
||||||
|
else:
|
||||||
|
x_pow = fluid.layers.pow(x, 2)
|
||||||
|
y = fluid.layers.tanh(x_pow)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class TestAST2Func(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
TestCase for the transformation from ast.AST into python callable function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _ast2func(self, func):
|
||||||
|
source = inspect.getsource(func)
|
||||||
|
source = textwrap.dedent(source)
|
||||||
|
ast_root = gast.parse(source)
|
||||||
|
transformed_func, _ = ast_to_func(ast_root, func.__name__)
|
||||||
|
return transformed_func
|
||||||
|
|
||||||
|
def test_ast2func(self):
|
||||||
|
def func(x, y):
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
x, y = 10, 20
|
||||||
|
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
|
||||||
|
|
||||||
|
def test_ast2func_dygraph(self):
|
||||||
|
func = dyfunc_with_if_else
|
||||||
|
x_data = np.random.random([10, 16]).astype('float32')
|
||||||
|
with fluid.dygraph.guard():
|
||||||
|
x_v = fluid.dygraph.to_variable(x_data)
|
||||||
|
true_ret = func(x_v).numpy()
|
||||||
|
test_ret = self._ast2func(func)(x_v).numpy()
|
||||||
|
self.assertTrue((true_ret == test_ret).all())
|
||||||
|
|
||||||
|
def test_ast2func_static(self):
|
||||||
|
def func(x):
|
||||||
|
y = fluid.layers.relu(x)
|
||||||
|
loss = fluid.layers.mean(y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
x_data = np.random.random([10, 16]).astype('float32')
|
||||||
|
main_program = fluid.Program()
|
||||||
|
with fluid.program_guard(main_program):
|
||||||
|
x_v = fluid.layers.assign(x_data)
|
||||||
|
true_ret = func(x_v)
|
||||||
|
test_ret = self._ast2func(func)(x_v)
|
||||||
|
exe = fluid.Executor(fluid.CPUPlace())
|
||||||
|
ret = exe.run(main_program, fetch_list=[true_ret, test_ret])
|
||||||
|
self.assertTrue((ret[0] == ret[1]).all())
|
||||||
|
|
||||||
|
def test_ast2func_error(self):
|
||||||
|
with self.assertRaises(Exception) as e:
|
||||||
|
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))
|
||||||
|
self.assertTrue("Type of ast_root should be gast.AST or ast.AST" in
|
||||||
|
str(e.exception))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in new issue