|
|
|
@ -489,9 +489,9 @@ CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
main_program=None):
|
|
|
|
|
checkpoint_dir,
|
|
|
|
|
main_program=None,
|
|
|
|
|
max_num_checkpoints=3):
|
|
|
|
|
"""
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
|
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
|
|
|
|
@ -500,12 +500,11 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param save_interval_secs
|
|
|
|
|
:param main_program
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
@ -518,7 +517,7 @@ def save_checkpoint(executor,
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load checkpoint from a directory by executor,
|
|
|
|
|
it will find the most recent saved checkpoint file and load it auto.
|
|
|
|
@ -529,7 +528,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
@ -546,7 +545,7 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
|
|
|
|
|
|
if delete_dir and not os.listdir(checkpoint_dir):
|
|
|
|
|