add clean checkpoint

shanyi15-patch-3
tangwei12 7 years ago
parent 192f9a5a70
commit cf3fb2488c

@ -529,6 +529,19 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
filename=None)
def clean_checkpoint(checkpoint_dir, delete_dir=False):
"""
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
"""
if checkpoint_dir is None:
checkpoint_dir = os.getcwd()
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
return os.path.join(checkpoint_dir, serial_folder)

Loading…
Cancel
Save