[Dy2Stat] Removes temporary files created during the transformation of dygraph to static graph. (#26150)

revert-24895-update_cub
liym27 5 years ago committed by GitHub
parent 361363c321
commit 1d730ffbf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys import sys
import traceback import traceback
@ -38,9 +39,27 @@ def attach_error_data(error, in_runtime=False):
setattr(error, ERROR_DATA, error_data) setattr(error, ERROR_DATA, error_data)
remove_static_file()
return error return error
def remove_static_file():
"""
Removes temporary files created during the transformation of dygraph to static graph.
"""
del_files = set()
for loc in global_origin_info_map:
static_filepath = loc[0]
del_files.add(static_filepath)
filename, extension = os.path.splitext(static_filepath)
del_files.add(filename + ".pyc")
for filepath in del_files:
if os.path.exists(filepath):
os.remove(filepath)
class TraceBackFrame(OriginInfo): class TraceBackFrame(OriginInfo):
""" """
Traceback frame information. Traceback frame information.

@ -368,6 +368,11 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
TODO: If only decorate one of inner function instead of decorating the main TODO: If only decorate one of inner function instead of decorating the main
function, the other inner functions are invisible for the decorated function. function, the other inner functions are invisible for the decorated function.
""" """
def remove_file(filepath):
if os.path.exists(filepath):
os.remove(filepath)
source = ast_to_source_code(ast_root) source = ast_to_source_code(ast_root)
import_fluid = "import paddle.fluid as fluid\n" import_fluid = "import paddle.fluid as fluid\n"
source = import_fluid + source source = import_fluid + source
@ -382,7 +387,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
f.write(source) f.write(source)
if delete_on_exit: if delete_on_exit:
atexit.register(lambda: os.remove(f.name)) atexit.register(lambda: remove_file(f.name))
atexit.register(lambda: remove_file(f.name[:-3] + ".pyc"))
module = imp.load_source(module_name, f.name) module = imp.load_source(module_name, f.name)
func_name = dyfunc.__name__ func_name = dyfunc.__name__
if not hasattr(module, func_name): if not hasattr(module, func_name):

Loading…
Cancel
Save