|
|
|
@ -299,6 +299,42 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids):
|
|
|
|
|
return func_def_node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImportVisitor(gast.NodeVisitor):
|
|
|
|
|
"""
|
|
|
|
|
Visitor to parse all `import` statement.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, file_name):
|
|
|
|
|
self.root = self.file_to_ast(file_name)
|
|
|
|
|
self.import_statements = []
|
|
|
|
|
|
|
|
|
|
def transform(self):
|
|
|
|
|
if self.root is not None:
|
|
|
|
|
self.visit(self.root)
|
|
|
|
|
self.after_visit()
|
|
|
|
|
return self.import_statements
|
|
|
|
|
|
|
|
|
|
def visit_Import(self, node):
|
|
|
|
|
self.import_statements.append(ast_to_source_code(node))
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def visit_ImportFrom(self, node):
|
|
|
|
|
self.import_statements.append(ast_to_source_code(node))
|
|
|
|
|
return node
|
|
|
|
|
|
|
|
|
|
def after_visit(self):
|
|
|
|
|
essential_statements = ["import paddle.fluid as fluid\n"]
|
|
|
|
|
new_stmts = set(essential_statements) - set(self.import_statements)
|
|
|
|
|
self.import_statements.extend(list(new_stmts))
|
|
|
|
|
|
|
|
|
|
def file_to_ast(self, file_name):
|
|
|
|
|
root = None
|
|
|
|
|
if file_name is not None:
|
|
|
|
|
with open(file_name) as f:
|
|
|
|
|
root = gast.parse(f.read())
|
|
|
|
|
return root
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def index_in_list(array_list, item):
|
|
|
|
|
try:
|
|
|
|
|
return array_list.index(item)
|
|
|
|
@ -307,7 +343,7 @@ def index_in_list(array_list, item):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ast_to_func(ast_root, func_name, delete_on_exit=True):
|
|
|
|
|
def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
|
|
|
|
|
"""
|
|
|
|
|
Transform modified AST of decorated function into python callable object.
|
|
|
|
|
"""
|
|
|
|
@ -318,13 +354,13 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
|
|
|
|
|
else:
|
|
|
|
|
f = tempfile.NamedTemporaryFile(
|
|
|
|
|
mode='w', suffix='.py', delete=False, encoding='utf-8')
|
|
|
|
|
|
|
|
|
|
# TODO(Aurelius84): more elegant way to transform ast into callable object
|
|
|
|
|
import_str = "import paddle\n" \
|
|
|
|
|
"import paddle.fluid as fluid\n" \
|
|
|
|
|
"import paddle.fluid.layers as layers\n" \
|
|
|
|
|
"import numpy as np\n" \
|
|
|
|
|
"import numpy\n"
|
|
|
|
|
# `sys.modules` is used to cache all modules and packages that avoids
|
|
|
|
|
# to import same modules twice by the import mechanism in python.
|
|
|
|
|
# We insert the import statements defined in source file into the tmpfile
|
|
|
|
|
# to make it easier to import external functions correctly.
|
|
|
|
|
source_file = inspect.getfile(dyfunc)
|
|
|
|
|
import_statements = ImportVisitor(source_file).transform()
|
|
|
|
|
import_str = "".join(import_statements)
|
|
|
|
|
with f:
|
|
|
|
|
module_name = os.path.basename(f.name[:-3])
|
|
|
|
|
f.write(import_str)
|
|
|
|
@ -333,6 +369,7 @@ def ast_to_func(ast_root, func_name, delete_on_exit=True):
|
|
|
|
|
if delete_on_exit:
|
|
|
|
|
atexit.register(lambda: os.remove(f.name))
|
|
|
|
|
module = imp.load_source(module_name, f.name)
|
|
|
|
|
func_name = dyfunc.__name__
|
|
|
|
|
if not hasattr(module, func_name):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'Function: %s doesn\'t exist in the Module transformed from AST.' %
|
|
|
|
|