|
|
|
@ -246,7 +246,7 @@ class StaticLayer(object):
|
|
|
|
|
self._function_spec = FunctionSpec(function, input_spec)
|
|
|
|
|
self._program_cache = ProgramCache()
|
|
|
|
|
self._descriptor_cache = weakref.WeakKeyDictionary()
|
|
|
|
|
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
|
|
|
|
|
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
|
|
|
|
|
self._program_trans = ProgramTranslator()
|
|
|
|
|
|
|
|
|
|
def __get__(self, instance, owner):
|
|
|
|
@ -299,16 +299,17 @@ class StaticLayer(object):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
# 1. call dygraph function directly if not enable `declarative`
|
|
|
|
|
if not self._program_trans.enable_declarative:
|
|
|
|
|
if not self._program_trans.enable_to_static:
|
|
|
|
|
logging_utils.warn(
|
|
|
|
|
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
|
|
|
|
|
"We will just return dygraph output.")
|
|
|
|
|
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
|
|
|
|
|
"We will just return dygraph output. If you would like to get static graph output, please call API "
|
|
|
|
|
"ProgramTranslator.enable(True)")
|
|
|
|
|
return self._call_dygraph_function(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
if not in_dygraph_mode() and self._program_trans.enable_declarative:
|
|
|
|
|
if not in_dygraph_mode():
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
|
|
|
|
|
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
|
|
|
|
|
"because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
|
|
|
|
|
"following API: paddle.disable_static().".format(
|
|
|
|
|
self.dygraph_function))
|
|
|
|
|
|
|
|
|
@ -723,15 +724,15 @@ class ProgramTranslator(object):
|
|
|
|
|
return
|
|
|
|
|
self._initialized = True
|
|
|
|
|
self._program_cache = ProgramCache()
|
|
|
|
|
self.enable_declarative = True
|
|
|
|
|
self.enable_to_static = True
|
|
|
|
|
|
|
|
|
|
def enable(self, enable_declarative):
|
|
|
|
|
def enable(self, enable_to_static):
|
|
|
|
|
"""
|
|
|
|
|
Enable or disable the converting from imperative to declarative by
|
|
|
|
|
ProgramTranslator globally.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
enable_declarative (bool): True or False to enable or disable declarative.
|
|
|
|
|
enable_to_static (bool): True or False to enable or disable declarative.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None.
|
|
|
|
@ -760,9 +761,9 @@ class ProgramTranslator(object):
|
|
|
|
|
print(func(x).numpy()) # [[2. 2.]]
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
check_type(enable_declarative, "enable_declarative", bool,
|
|
|
|
|
check_type(enable_to_static, "enable_to_static", bool,
|
|
|
|
|
"ProgramTranslator.enable")
|
|
|
|
|
self.enable_declarative = enable_declarative
|
|
|
|
|
self.enable_to_static = enable_to_static
|
|
|
|
|
|
|
|
|
|
def get_output(self, dygraph_func, *args, **kwargs):
|
|
|
|
|
"""
|
|
|
|
@ -803,10 +804,12 @@ class ProgramTranslator(object):
|
|
|
|
|
assert callable(
|
|
|
|
|
dygraph_func
|
|
|
|
|
), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
|
|
|
|
|
if not self.enable_declarative:
|
|
|
|
|
if not self.enable_to_static:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
|
|
|
|
|
"We will just return dygraph output.")
|
|
|
|
|
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
|
|
|
|
|
"We will just return dygraph output. "
|
|
|
|
|
"Please call ProgramTranslator.enable(True) if you would like to get static output."
|
|
|
|
|
)
|
|
|
|
|
return dygraph_func(*args, **kwargs)
|
|
|
|
|
try:
|
|
|
|
|
function_spec = FunctionSpec(dygraph_func)
|
|
|
|
@ -876,10 +879,11 @@ class ProgramTranslator(object):
|
|
|
|
|
assert callable(
|
|
|
|
|
dygraph_func
|
|
|
|
|
), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
|
|
|
|
|
if not self.enable_declarative:
|
|
|
|
|
if not self.enable_to_static:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
|
|
|
|
|
"just return dygraph output.")
|
|
|
|
|
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
|
|
|
|
|
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
|
|
|
|
|
)
|
|
|
|
|
return dygraph_func
|
|
|
|
|
|
|
|
|
|
static_func = convert_to_static(dygraph_func)
|
|
|
|
@ -929,10 +933,12 @@ class ProgramTranslator(object):
|
|
|
|
|
assert callable(
|
|
|
|
|
dygraph_func
|
|
|
|
|
), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
|
|
|
|
|
if not self.enable_declarative:
|
|
|
|
|
if not self.enable_to_static:
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False."
|
|
|
|
|
"We will just return dygraph output.")
|
|
|
|
|
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
|
|
|
|
|
"We will just return dygraph output. "
|
|
|
|
|
"Please call ProgramTranslator.enable(True) if you would like to get static output."
|
|
|
|
|
)
|
|
|
|
|
return dygraph_func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
function_spec = FunctionSpec(dygraph_func)
|
|
|
|
|