|
|
|
@ -529,6 +529,19 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
"""
|
|
|
|
|
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
|
|
|
|
|
|
if delete_dir and not os.listdir(checkpoint_dir):
|
|
|
|
|
os.rmdir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(serial, checkpoint_dir):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
return os.path.join(checkpoint_dir, serial_folder)
|
|
|
|
|