|
|
@ -30,6 +30,7 @@ import six
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
|
|
|
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
|
|
|
|
from paddle.fluid.dygraph.layers import Layer
|
|
|
|
from paddle.fluid.dygraph.layers import Layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
|
|
|
|
program_translator = ProgramTranslator()
|
|
|
|
program_translator = ProgramTranslator()
|
|
|
|
to_static_func = program_translator.get_func
|
|
|
|
to_static_func = program_translator.get_func
|
|
|
|
|
|
|
|
|
|
|
@ -102,8 +103,17 @@ def convert_call(func):
|
|
|
|
return func
|
|
|
|
return func
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if func in func.__globals__.values():
|
|
|
|
if func in func.__globals__.values():
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
if six.PY3:
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
source_code = inspect.getsource(func)
|
|
|
|
|
|
|
|
if any(decorator in source_code
|
|
|
|
|
|
|
|
for decorator in DECORATOR_NAMES):
|
|
|
|
|
|
|
|
converted_call = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
except AttributeError:
|
|
|
|
except AttributeError:
|
|
|
|
# NOTE:
|
|
|
|
# NOTE:
|
|
|
|
# If func is not in __globals__, it does not need to be transformed
|
|
|
|
# If func is not in __globals__, it does not need to be transformed
|
|
|
@ -116,8 +126,17 @@ def convert_call(func):
|
|
|
|
converted_call = None
|
|
|
|
converted_call = None
|
|
|
|
elif inspect.ismethod(func):
|
|
|
|
elif inspect.ismethod(func):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
if six.PY3:
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
source_code = inspect.getsource(func)
|
|
|
|
|
|
|
|
if any(decorator in source_code
|
|
|
|
|
|
|
|
for decorator in DECORATOR_NAMES):
|
|
|
|
|
|
|
|
converted_call = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
converted_call = to_static_func(func)
|
|
|
|
|
|
|
|
func_self = getattr(func, '__self__', None)
|
|
|
|
except (IOError, OSError):
|
|
|
|
except (IOError, OSError):
|
|
|
|
# NOTE: func may have beed decorated.
|
|
|
|
# NOTE: func may have beed decorated.
|
|
|
|
converted_call = None
|
|
|
|
converted_call = None
|
|
|
@ -125,9 +144,20 @@ def convert_call(func):
|
|
|
|
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
|
|
|
|
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
|
|
|
|
if hasattr(func, 'forward') and isinstance(func, Layer):
|
|
|
|
if hasattr(func, 'forward') and isinstance(func, Layer):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
forward_func = to_static_func(func.forward)
|
|
|
|
if six.PY3:
|
|
|
|
setattr(func, 'forward', forward_func)
|
|
|
|
source_code = inspect.getsource(func.forward)
|
|
|
|
func_self = func
|
|
|
|
if any(decorator in source_code
|
|
|
|
|
|
|
|
for decorator in DECORATOR_NAMES):
|
|
|
|
|
|
|
|
converted_call = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
forward_func = to_static_func(func.forward)
|
|
|
|
|
|
|
|
setattr(func, 'forward', forward_func)
|
|
|
|
|
|
|
|
func_self = func
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
forward_func = to_static_func(func.forward)
|
|
|
|
|
|
|
|
setattr(func, 'forward', forward_func)
|
|
|
|
|
|
|
|
func_self = func
|
|
|
|
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
except Exception:
|
|
|
|
# NOTE: func.forward may have beed decorated.
|
|
|
|
# NOTE: func.forward may have beed decorated.
|
|
|
|
func_self = None if func_self else func_self
|
|
|
|
func_self = None if func_self else func_self
|
|
|
@ -148,5 +178,4 @@ def convert_call(func):
|
|
|
|
|
|
|
|
|
|
|
|
if func_self:
|
|
|
|
if func_self:
|
|
|
|
converted_call = functools.partial(converted_call, func_self)
|
|
|
|
converted_call = functools.partial(converted_call, func_self)
|
|
|
|
|
|
|
|
|
|
|
|
return converted_call
|
|
|
|
return converted_call
|
|
|
|