|
|
@ -19,7 +19,6 @@ import astor
|
|
|
|
import atexit
|
|
|
|
import atexit
|
|
|
|
import copy
|
|
|
|
import copy
|
|
|
|
import gast
|
|
|
|
import gast
|
|
|
|
import imp
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
import inspect
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import six
|
|
|
|
import six
|
|
|
@ -28,6 +27,12 @@ import textwrap
|
|
|
|
|
|
|
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# imp is deprecated in python3
|
|
|
|
|
|
|
|
if six.PY2:
|
|
|
|
|
|
|
|
import imp
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
from importlib.machinery import SourceFileLoader
|
|
|
|
|
|
|
|
|
|
|
|
dygraph_class_to_static_api = {
|
|
|
|
dygraph_class_to_static_api = {
|
|
|
|
"CosineDecay": "cosine_decay",
|
|
|
|
"CosineDecay": "cosine_decay",
|
|
|
|
"ExponentialDecay": "exponential_decay",
|
|
|
|
"ExponentialDecay": "exponential_decay",
|
|
|
@ -391,7 +396,10 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
|
|
|
|
atexit.register(lambda: remove_if_exit(f.name))
|
|
|
|
atexit.register(lambda: remove_if_exit(f.name))
|
|
|
|
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc"))
|
|
|
|
atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc"))
|
|
|
|
|
|
|
|
|
|
|
|
module = imp.load_source(module_name, f.name)
|
|
|
|
if six.PY2:
|
|
|
|
|
|
|
|
module = imp.load_source(module_name, f.name)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
module = SourceFileLoader(module_name, f.name).load_module()
|
|
|
|
func_name = dyfunc.__name__
|
|
|
|
func_name = dyfunc.__name__
|
|
|
|
if not hasattr(module, func_name):
|
|
|
|
if not hasattr(module, func_name):
|
|
|
|
raise ValueError(
|
|
|
|
raise ValueError(
|
|
|
|