|
|
|
@ -31,6 +31,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
|
|
|
|
|
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from paddle.fluid.dygraph.base import param_guard
|
|
|
|
|
from paddle.fluid.data_feeder import check_type
|
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
|
|
|
|
@ -155,6 +156,28 @@ class FunctionSpec(object):
|
|
|
|
|
return self.__key() == self.__key()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flag that indicates whether running code under `@declarative`
|
|
|
|
|
_in_declarative_mode_ = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def in_declarative_mode():
|
|
|
|
|
"""
|
|
|
|
|
Return a bool value that indicates whether running code under `@declarative`
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
return _in_declarative_mode_
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@signature_safe_contextmanager
|
|
|
|
|
def _switch_declarative_mode_guard_(is_declarative=True):
|
|
|
|
|
|
|
|
|
|
global _in_declarative_mode_
|
|
|
|
|
original_val = _in_declarative_mode_
|
|
|
|
|
_in_declarative_mode_ = is_declarative
|
|
|
|
|
yield
|
|
|
|
|
_in_declarative_mode_ = original_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConcreteProgram(object):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
inputs,
|
|
|
|
@ -190,17 +213,18 @@ class ConcreteProgram(object):
|
|
|
|
|
).random_seed
|
|
|
|
|
|
|
|
|
|
with framework.program_guard(main_program, startup_program):
|
|
|
|
|
# 1. Adds `fluid.data` layers for input if needed
|
|
|
|
|
inputs = func_spec.to_static_inputs(main_program)
|
|
|
|
|
|
|
|
|
|
# 2. Gets all ParamBases in the function
|
|
|
|
|
all_parameters = list(func_spec.parameters().values())
|
|
|
|
|
|
|
|
|
|
# 3. Builds program only once and returns the output Variables.
|
|
|
|
|
with param_guard(func_spec.parameters(False)):
|
|
|
|
|
outputs = static_func(*inputs)
|
|
|
|
|
if not isinstance(outputs, (tuple, list)):
|
|
|
|
|
outputs = [outputs] if outputs else []
|
|
|
|
|
with _switch_declarative_mode_guard_(is_declarative=True):
|
|
|
|
|
# 1. Adds `fluid.data` layers for input if needed
|
|
|
|
|
inputs = func_spec.to_static_inputs(main_program)
|
|
|
|
|
|
|
|
|
|
# 2. Gets all ParamBases in the function
|
|
|
|
|
all_parameters = list(func_spec.parameters().values())
|
|
|
|
|
|
|
|
|
|
# 3. Builds program only once and returns the output Variables.
|
|
|
|
|
with param_guard(func_spec.parameters(False)):
|
|
|
|
|
outputs = static_func(*inputs)
|
|
|
|
|
if not isinstance(outputs, (tuple, list)):
|
|
|
|
|
outputs = [outputs] if outputs else []
|
|
|
|
|
|
|
|
|
|
return ConcreteProgram(
|
|
|
|
|
inputs=inputs,
|
|
|
|
|