|
|
|
@ -28,15 +28,16 @@ from paddle.fluid.dygraph import layers
|
|
|
|
|
from paddle.fluid.layers.utils import flatten
|
|
|
|
|
from paddle.fluid.layers.utils import pack_sequence_as
|
|
|
|
|
from paddle.fluid.dygraph.base import switch_to_static_graph
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
|
|
|
|
|
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from paddle.fluid.dygraph.base import param_guard
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
|
|
|
|
|
|
|
|
|
|
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
|
|
|
|
|
__all__ = ['ProgramTranslator', 'convert_to_static']
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger("fluid")
|
|
|
|
|
|
|
|
|
@ -47,43 +48,76 @@ class FunctionCache(object):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self._dycode_to_static_func = dict()
|
|
|
|
|
self._static_func_to_transformer = dict()
|
|
|
|
|
# Caches the converted static functions. {dygraph_func: static_func}
|
|
|
|
|
self._converted_static_func_caches = dict()
|
|
|
|
|
# Caches the converted ast node for same source code. {source_code: ast_root}
|
|
|
|
|
self._code_to_ast_caches = dict()
|
|
|
|
|
self._dygraph_to_static = DygraphToStaticAst()
|
|
|
|
|
|
|
|
|
|
def get_or_cache_func(self, func):
|
|
|
|
|
# code = self._get_dedent_code_string(func)
|
|
|
|
|
static_func = self._dycode_to_static_func.get(func, None)
|
|
|
|
|
def convert_with_cache(self, func):
|
|
|
|
|
"""
|
|
|
|
|
Returns the cached static function or converts it when first encounters the function.
|
|
|
|
|
"""
|
|
|
|
|
# If hit cache, return it directly.
|
|
|
|
|
static_func = self._converted_static_func_caches.get(func, None)
|
|
|
|
|
|
|
|
|
|
if static_func is None:
|
|
|
|
|
static_func, dygraph_to_static_transformer = convert_to_static(func)
|
|
|
|
|
self._dycode_to_static_func[func] = static_func
|
|
|
|
|
self._static_func_to_transformer[
|
|
|
|
|
func] = dygraph_to_static_transformer
|
|
|
|
|
static_func = self._convert(func)
|
|
|
|
|
self._converted_static_func_caches[func] = static_func
|
|
|
|
|
|
|
|
|
|
return static_func
|
|
|
|
|
|
|
|
|
|
def get_transformer(self, func):
|
|
|
|
|
return self._static_func_to_transformer.get(func, None)
|
|
|
|
|
def _convert(self, func):
|
|
|
|
|
"""
|
|
|
|
|
Converts dygraph function into static function. For two functions with same dedent code,
|
|
|
|
|
the second function will reuse the transformed ast node of previous one.
|
|
|
|
|
|
|
|
|
|
For example:
|
|
|
|
|
# A.py
|
|
|
|
|
def foo(x, y):
|
|
|
|
|
z = x + y
|
|
|
|
|
return z
|
|
|
|
|
|
|
|
|
|
# B.py
|
|
|
|
|
def foo(x, y):
|
|
|
|
|
z = x + y
|
|
|
|
|
return z
|
|
|
|
|
|
|
|
|
|
If the conversion of A.foo happens after B.foo, it will reuse the transformed ast node of B.foo
|
|
|
|
|
to speed up the conversion.
|
|
|
|
|
"""
|
|
|
|
|
# Note: In Python2, it will raise OSError when inspect function
|
|
|
|
|
# with decorator directly and function.__wrapped__ holds the actual function.
|
|
|
|
|
func = getattr(func, '__wrapped__', func)
|
|
|
|
|
source_code = func_to_source_code(func)
|
|
|
|
|
if source_code in self._code_to_ast_caches:
|
|
|
|
|
root_wrapper = self._code_to_ast_caches[source_code]
|
|
|
|
|
else:
|
|
|
|
|
root = gast.parse(source_code)
|
|
|
|
|
root_wrapper = self._dygraph_to_static.get_static_ast(root)
|
|
|
|
|
self._code_to_ast_caches[source_code] = root_wrapper
|
|
|
|
|
|
|
|
|
|
def _get_dedent_code_string(self, func):
|
|
|
|
|
raw_code = inspect.getsource(func)
|
|
|
|
|
dedent_code = textwrap.dedent(raw_code)
|
|
|
|
|
return dedent_code
|
|
|
|
|
# Get static function from AST
|
|
|
|
|
static_func, file_name = ast_to_func(root_wrapper.node, func)
|
|
|
|
|
return static_func
|
|
|
|
|
|
|
|
|
|
def exist(self, func):
|
|
|
|
|
return self._dycode_to_static_func.get(func, None) is not None
|
|
|
|
|
return func in self._converted_static_func_caches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_CACHE_LOCK = threading.Lock()
|
|
|
|
|
_FUNCTION_CACHE = FunctionCache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_function_with_cache(dygraph_func):
|
|
|
|
|
def convert_to_static(function):
|
|
|
|
|
"""
|
|
|
|
|
Transforms function of dygraph into static function using the cache mechanism.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
function(callable): The function with dygraph layers that will be converted into static layers.
|
|
|
|
|
"""
|
|
|
|
|
with _CACHE_LOCK:
|
|
|
|
|
static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
|
|
|
|
|
static_func = _FUNCTION_CACHE.convert_with_cache(function)
|
|
|
|
|
return static_func
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -202,7 +236,7 @@ class ConcreteProgram(object):
|
|
|
|
|
"""
|
|
|
|
|
# Transforms dygraph function into static function and caches it.
|
|
|
|
|
dygraph_function = func_spec.dyfunc
|
|
|
|
|
static_func = convert_function_with_cache(dygraph_function)
|
|
|
|
|
static_func = convert_to_static(dygraph_function)
|
|
|
|
|
|
|
|
|
|
main_program, startup_program = framework.Program(), framework.Program()
|
|
|
|
|
# Note: The random seed should be synchronized into cached program
|
|
|
|
@ -461,7 +495,7 @@ class ProgramTranslator(object):
|
|
|
|
|
"just return dygraph output.")
|
|
|
|
|
return dygraph_func
|
|
|
|
|
|
|
|
|
|
static_func = convert_function_with_cache(dygraph_func)
|
|
|
|
|
static_func = convert_to_static(dygraph_func)
|
|
|
|
|
return static_func
|
|
|
|
|
|
|
|
|
|
def get_program(self, dygraph_func, *args, **kwargs):
|
|
|
|
|