[dy2static]Support recursive call (#23900)
* [Dy2Stat]Support recursive call. * Remove Redundant decorator to pass the Py35 unittest temporarily.revert-22778-infer_var_type
parent
ad7ac4c607
commit
0b0adbf9b6
@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import gast
|
||||
|
||||
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
||||
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
|
||||
|
||||
|
||||
class CallTransformer(gast.NodeTransformer):
|
||||
"""
|
||||
This class transforms function calls into Static Graph Ast.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapper_root):
|
||||
assert isinstance(
|
||||
wrapper_root, AstNodeWrapper
|
||||
), "Input non-AstNodeWrapper node for the initialization of CallTransformer."
|
||||
self.wrapper_root = wrapper_root
|
||||
self.root = wrapper_root.node
|
||||
|
||||
def transform(self):
|
||||
self.visit(self.root)
|
||||
|
||||
def visit_Call(self, node):
|
||||
self.generic_visit(node)
|
||||
if is_paddle_api(node):
|
||||
return node
|
||||
func_str = ast_to_source_code(node.func).strip()
|
||||
new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
|
||||
func_str)
|
||||
new_func_ast = gast.parse(new_func_str).body[0].value
|
||||
node.func = new_func_ast
|
||||
|
||||
return node
|
@ -0,0 +1,152 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['convert_call']
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import pdb
|
||||
import re
|
||||
import types
|
||||
|
||||
import numpy
|
||||
import six
|
||||
|
||||
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
||||
from paddle.fluid.dygraph.layers import Layer
|
||||
|
||||
program_translator = ProgramTranslator()
|
||||
to_static_func = program_translator.get_func
|
||||
|
||||
|
||||
def is_builtin(func):
|
||||
if isinstance(func, types.BuiltinFunctionType):
|
||||
return True
|
||||
elif func in six.moves.builtins.__dict__.values():
|
||||
return True
|
||||
# Other built-in modules
|
||||
# TODO(liym27): A better way to do this.
|
||||
elif any(func in m.__dict__.values()
|
||||
for m in (collections, pdb, copy, inspect, re, six, numpy)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_paddle_func(func):
|
||||
m = inspect.getmodule(func)
|
||||
return m is not None and m.__name__.startswith("paddle")
|
||||
|
||||
|
||||
def convert_call(func):
|
||||
"""
|
||||
Converts a function call which needs to be transformed to static fucntion.
|
||||
|
||||
Args:
|
||||
func (callable): A callable function or method to convert.
|
||||
|
||||
Returns:
|
||||
Callable: A converted function.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.dygraph.dygraph_to_static import convert_call
|
||||
|
||||
def dyfunc(x):
|
||||
if fluid.layers.mean(x) < 0:
|
||||
x_v = x - 1
|
||||
else:
|
||||
x_v = x + 1
|
||||
|
||||
return x_v
|
||||
new_func = convert_call(dyfunc)
|
||||
x = fluid.layers.fill_constant(shape=[3, 3], value=0, dtype='float64')
|
||||
x_v = new_func(x)
|
||||
exe = fluid.Executor(fluid.CPUPlace())
|
||||
out = exe.run(fetch_list=[x_v])
|
||||
print(out[0])
|
||||
# [[1. 1. 1.]
|
||||
# [1. 1. 1.]
|
||||
# [1. 1. 1.]]
|
||||
|
||||
"""
|
||||
func_self = None
|
||||
converted_call = None
|
||||
|
||||
if is_builtin(func):
|
||||
return func
|
||||
|
||||
if is_paddle_func(func):
|
||||
return func
|
||||
|
||||
if inspect.isfunction(func):
|
||||
# TODO(liym27): If func is a lambda function, special conversion is needed.
|
||||
if func.__name__ == '<lambda>':
|
||||
return func
|
||||
try:
|
||||
if func in func.__globals__.values():
|
||||
converted_call = to_static_func(func)
|
||||
func_self = getattr(func, '__self__', None)
|
||||
except AttributeError:
|
||||
# NOTE:
|
||||
# If func is not in __globals__, it does not need to be transformed
|
||||
# because it has been transformed before.
|
||||
converted_call = None
|
||||
except (IOError, OSError):
|
||||
# NOTE:
|
||||
# If func has beed decorated, its source code can not be get
|
||||
# so that it can not be transformed to static function.
|
||||
converted_call = None
|
||||
elif inspect.ismethod(func):
|
||||
try:
|
||||
func_self = getattr(func, '__self__', None)
|
||||
converted_call = to_static_func(func)
|
||||
except (IOError, OSError):
|
||||
# NOTE: func may have beed decorated.
|
||||
converted_call = None
|
||||
|
||||
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
|
||||
if hasattr(func, 'forward') and isinstance(func, Layer):
|
||||
try:
|
||||
forward_func = to_static_func(func.forward)
|
||||
setattr(func, 'forward', forward_func)
|
||||
func_self = func
|
||||
except Exception:
|
||||
# NOTE: func.forward may have beed decorated.
|
||||
func_self = None if func_self else func_self
|
||||
converted_call = func
|
||||
else:
|
||||
try:
|
||||
call_func = func.__class__.__call__
|
||||
converted_call = to_static_func(call_func)
|
||||
func_self = func
|
||||
except Exception:
|
||||
# NOTE:
|
||||
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
|
||||
# it doesn't need to be transformed
|
||||
func_self = None if func_self else func_self
|
||||
|
||||
if converted_call is None:
|
||||
return func
|
||||
|
||||
if func_self:
|
||||
converted_call = functools.partial(converted_call, func_self)
|
||||
|
||||
return converted_call
|
@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.dygraph.jit import dygraph_to_static_func
|
||||
|
||||
SEED = 2020
|
||||
np.random.seed(SEED)
|
||||
|
||||
|
||||
def dyfunc_with_if(x_v):
|
||||
if fluid.layers.mean(x_v).numpy()[0] > 5:
|
||||
x_v = x_v - 1
|
||||
else:
|
||||
x_v = x_v + 1
|
||||
return x_v
|
||||
|
||||
|
||||
@dygraph_to_static_func
|
||||
def nested_func(x_v):
|
||||
x_v = fluid.dygraph.to_variable(x_v)
|
||||
|
||||
def fn1():
|
||||
return x_v
|
||||
|
||||
res = fn1()
|
||||
return res
|
||||
|
||||
|
||||
class TestRecursiveCall1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.input = np.random.random([10, 16]).astype('float32')
|
||||
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||
) else fluid.CPUPlace()
|
||||
self.init_test_func()
|
||||
|
||||
def init_test_func(self):
|
||||
self.dyfunc = nested_func
|
||||
|
||||
def get_dygraph_output(self):
|
||||
with fluid.dygraph.guard():
|
||||
res = self.dyfunc(self.input).numpy()
|
||||
return res
|
||||
|
||||
def get_static_output(self):
|
||||
main_program = fluid.Program()
|
||||
with fluid.program_guard(main_program):
|
||||
static_out = self.dyfunc(self.input)
|
||||
exe = fluid.Executor(self.place)
|
||||
static_res = exe.run(main_program, fetch_list=static_out)
|
||||
return static_res[0]
|
||||
|
||||
def test_transformed_static_result(self):
|
||||
static_res = self.get_static_output()
|
||||
dygraph_res = self.get_dygraph_output()
|
||||
self.assertTrue(
|
||||
np.allclose(dygraph_res, static_res),
|
||||
msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res,
|
||||
static_res))
|
||||
|
||||
|
||||
lambda_fun = lambda x: x
|
||||
|
||||
|
||||
class MyConvLayer(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(MyConvLayer, self).__init__()
|
||||
self._conv = fluid.dygraph.Conv2D(
|
||||
num_channels=3,
|
||||
num_filters=2,
|
||||
filter_size=3,
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=0.99)),
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=0.5)))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = dyfunc_with_if(inputs)
|
||||
y = lambda_fun(y)
|
||||
y = self.dymethod(y)
|
||||
return y
|
||||
|
||||
@dygraph_to_static_func
|
||||
def dymethod(self, x_v):
|
||||
return x_v
|
||||
|
||||
|
||||
class MyLayer(fluid.dygraph.Layer):
|
||||
def __init__(self):
|
||||
super(MyLayer, self).__init__()
|
||||
|
||||
self.conv = MyConvLayer()
|
||||
self.fc = fluid.dygraph.Linear(
|
||||
input_dim=5,
|
||||
output_dim=1,
|
||||
act='relu',
|
||||
param_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=0.99)),
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(value=0.5)))
|
||||
|
||||
@dygraph_to_static_func
|
||||
def forward(self, inputs):
|
||||
h = self.conv(inputs)
|
||||
out = self.fc(h)
|
||||
return out
|
||||
|
||||
|
||||
class TestRecursiveCall2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.input = np.random.random((1, 3, 3, 5)).astype('float32')
|
||||
self.Layer = MyLayer
|
||||
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||
) else fluid.CPUPlace()
|
||||
|
||||
def get_dygraph_output(self):
|
||||
with fluid.dygraph.guard():
|
||||
self.dygraph_func = self.Layer()
|
||||
fluid.default_startup_program.random_seed = SEED
|
||||
fluid.default_main_program.random_seed = SEED
|
||||
data = fluid.dygraph.to_variable(self.input)
|
||||
res = self.dygraph_func(data)
|
||||
|
||||
return res.numpy()
|
||||
|
||||
def get_static_output(self):
|
||||
startup_program = fluid.Program()
|
||||
startup_program.random_seed = SEED
|
||||
main_program = fluid.Program()
|
||||
main_program.random_seed = SEED
|
||||
|
||||
with fluid.program_guard(main_program, startup_program):
|
||||
self.dygraph_func = self.Layer()
|
||||
data = fluid.layers.assign(self.input)
|
||||
static_out = self.dygraph_func(data)
|
||||
|
||||
exe = fluid.Executor(self.place)
|
||||
exe.run(startup_program)
|
||||
static_res = exe.run(main_program, fetch_list=static_out)
|
||||
return static_res[0]
|
||||
|
||||
def test_transformed_static_result(self):
|
||||
dygraph_res = self.get_dygraph_output()
|
||||
static_res = self.get_static_output()
|
||||
self.assertTrue(
|
||||
np.allclose(dygraph_res, static_res),
|
||||
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_res,
|
||||
static_res))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue