|
|
|
@ -14,11 +14,12 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
__all__ = ['TracedLayer', 'dygraph_to_static_output']
|
|
|
|
|
__all__ = ['TracedLayer', 'dygraph_to_static_output', 'dygraph_to_static_graph']
|
|
|
|
|
|
|
|
|
|
import gast
|
|
|
|
|
import inspect
|
|
|
|
|
import textwrap
|
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
|
|
from ..wrapped_decorator import wrap_decorator
|
|
|
|
|
from .base import program_desc_tracing_guard, switch_to_static_graph
|
|
|
|
@ -29,7 +30,7 @@ from paddle.fluid import core
|
|
|
|
|
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
|
|
|
|
|
from paddle.fluid.executor import Executor, scope_guard
|
|
|
|
|
from paddle.fluid.compiler import CompiledProgram
|
|
|
|
|
from paddle.fluid import program_guard, data
|
|
|
|
|
from paddle.fluid import program_guard, data, default_startup_program, default_main_program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_program_from_desc(program_desc):
|
|
|
|
@ -55,43 +56,60 @@ def extract_vars(inputs):
|
|
|
|
|
return result_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dygraph_to_static_output_(dygraph_func):
|
|
|
|
|
def __impl__(*args, **kwargs):
|
|
|
|
|
def to_static_func(dygraph_func):
|
|
|
|
|
# Get AST from dygraph function
|
|
|
|
|
dygraph_code = inspect.getsource(dygraph_func)
|
|
|
|
|
dygraph_code = textwrap.dedent(dygraph_code)
|
|
|
|
|
root = gast.parse(dygraph_code)
|
|
|
|
|
|
|
|
|
|
# Get AST from dygraph function
|
|
|
|
|
dygraph_code = inspect.getsource(dygraph_func)
|
|
|
|
|
dygraph_code = textwrap.dedent(dygraph_code)
|
|
|
|
|
root = gast.parse(dygraph_code)
|
|
|
|
|
# Transform AST
|
|
|
|
|
dygraph_to_static = DygraphToStaticAst()
|
|
|
|
|
root_wrapper = dygraph_to_static.get_static_ast(root)
|
|
|
|
|
|
|
|
|
|
# Transform AST
|
|
|
|
|
dygraph_to_static = DygraphToStaticAst()
|
|
|
|
|
root_wrapper = dygraph_to_static.get_static_ast(root)
|
|
|
|
|
# Get static_func from AST
|
|
|
|
|
func_name = dygraph_to_static.get_module_name()
|
|
|
|
|
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
|
|
|
|
|
|
|
|
|
|
# Get static_func from AST
|
|
|
|
|
func_name = dygraph_to_static.get_module_name()
|
|
|
|
|
static_func, file_name = ast_to_func(root_wrapper.node, func_name)
|
|
|
|
|
return static_func, dygraph_to_static
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dygraph_to_static_graph_(dygraph_func):
|
|
|
|
|
def __impl__(*args, **kwargs):
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
|
|
|
|
|
" Please use it in static mode.")
|
|
|
|
|
return dygraph_func(*args, **kwargs)
|
|
|
|
|
static_func, dygraph_to_static = to_static_func(dygraph_func)
|
|
|
|
|
return static_func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return __impl__
|
|
|
|
|
|
|
|
|
|
if not in_dygraph_mode():
|
|
|
|
|
return static_func(*args, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
|
|
|
|
|
feed_dict = {}
|
|
|
|
|
for feed_name, idx in feed_name_to_idx.items():
|
|
|
|
|
feed_dict[feed_name] = args[idx]
|
|
|
|
|
|
|
|
|
|
# Run static_func in static mode
|
|
|
|
|
startup_program = Program()
|
|
|
|
|
main_program = Program()
|
|
|
|
|
static_res = run_static_func(main_program, startup_program,
|
|
|
|
|
static_func, args, kwargs, feed_dict,
|
|
|
|
|
feed_name_to_idx)
|
|
|
|
|
|
|
|
|
|
def _dygraph_to_static_output_(dygraph_func):
|
|
|
|
|
def __impl__(*args, **kwargs):
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
|
|
|
|
|
" Please use it in static mode.")
|
|
|
|
|
return dygraph_func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
static_func, dygraph_to_static = to_static_func(dygraph_func)
|
|
|
|
|
feed_name_to_idx = dygraph_to_static.get_feed_name_to_idx()
|
|
|
|
|
feed_dict = {}
|
|
|
|
|
for feed_name, idx in feed_name_to_idx.items():
|
|
|
|
|
feed_dict[feed_name] = args[idx]
|
|
|
|
|
|
|
|
|
|
# Run static_func in static mode
|
|
|
|
|
startup_program = default_main_program()
|
|
|
|
|
main_program = default_startup_program()
|
|
|
|
|
static_res = run_static_func(main_program, startup_program, static_func,
|
|
|
|
|
args, kwargs, feed_dict, feed_name_to_idx)
|
|
|
|
|
return static_res
|
|
|
|
|
|
|
|
|
|
return __impl__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
|
def run_static_func(main_program, startup_program, static_func, args, kwargs,
|
|
|
|
|
feed_dict, feed_name_to_idx):
|
|
|
|
|
|
|
|
|
@ -114,6 +132,7 @@ def run_static_func(main_program, startup_program, static_func, args, kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dygraph_to_static_output = wrap_decorator(_dygraph_to_static_output_)
|
|
|
|
|
dygraph_to_static_graph = wrap_decorator(_dygraph_to_static_graph_)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dygraph_only
|
|
|
|
|