|
|
|
@ -478,9 +478,10 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param trainer_id
|
|
|
|
|
:param is_chief
|
|
|
|
|
:param main_program
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param is_chief
|
|
|
|
|
"""
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
@ -502,6 +503,11 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def need_load_checkpoint(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
If the directory have checkpoint files, it will return lastest checkpoint directory serial number
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
"""
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return None
|
|
|
|
@ -515,6 +521,7 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param serial
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -536,7 +543,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
"""
|
|
|
|
|
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param delete_dir
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
raise ValueError("The values of 'checkpoint_dir' should not be None")
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints=0)
|
|
|
|
@ -549,6 +560,11 @@ def load_persist_vars_without_grad(executor, dirname, program, nest=True):
|
|
|
|
|
"""
|
|
|
|
|
load_persist_vars_without_grad will load variables from a directory by an executor,
|
|
|
|
|
the variable named end with "@GRAD" will not be loaded.
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param dirname
|
|
|
|
|
:param program
|
|
|
|
|
:param nest
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if nest:
|
|
|
|
@ -566,6 +582,10 @@ def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
"""
|
|
|
|
|
save_persist_vars_without_grad will save variables to a directory by an executor,
|
|
|
|
|
the variable named end with "@GRAD" will not be saved.
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param dirname
|
|
|
|
|
:param program
|
|
|
|
|
"""
|
|
|
|
|
cur_dir = _get_model_dir(dirname)
|
|
|
|
|
save_vars(
|
|
|
|
|