[Dy2Stat] Removes temporary files created during the transformation of dygraph to static graph. (#26150)

revert-24895-update_cub
liym27 5 years ago committed by GitHub
parent 361363c321
commit 1d730ffbf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import traceback
@ -38,9 +39,27 @@ def attach_error_data(error, in_runtime=False):
setattr(error, ERROR_DATA, error_data)
remove_static_file()
return error
def remove_static_file():
"""
Removes temporary files created during the transformation of dygraph to static graph.
"""
del_files = set()
for loc in global_origin_info_map:
static_filepath = loc[0]
del_files.add(static_filepath)
filename, extension = os.path.splitext(static_filepath)
del_files.add(filename + ".pyc")
for filepath in del_files:
if os.path.exists(filepath):
os.remove(filepath)
class TraceBackFrame(OriginInfo):
"""
Traceback frame information.

@ -368,6 +368,11 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
TODO: If only decorate one of inner function instead of decorating the main
function, the other inner functions are invisible for the decorated function.
"""
def remove_file(filepath):
if os.path.exists(filepath):
os.remove(filepath)
source = ast_to_source_code(ast_root)
import_fluid = "import paddle.fluid as fluid\n"
source = import_fluid + source
@ -382,7 +387,9 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True):
f.write(source)
if delete_on_exit:
atexit.register(lambda: os.remove(f.name))
atexit.register(lambda: remove_file(f.name))
atexit.register(lambda: remove_file(f.name[:-3] + ".pyc"))
module = imp.load_source(module_name, f.name)
func_name = dyfunc.__name__
if not hasattr(module, func_name):

Loading…
Cancel
Save