|
|
|
@ -463,7 +463,10 @@ def save_checkpoint(executor,
|
|
|
|
|
save_interval_secs=600,
|
|
|
|
|
main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Save Variables to Checkpoint Directory
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
|
directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
|
|
|
|
|
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
|
|
|
|
|
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
@ -489,7 +492,7 @@ def save_checkpoint(executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
vars=None,
|
|
|
|
|
predicate=is_checkpoint_var,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
_lru_delete(dirname, max_num_checkpoints)
|
|
|
|
@ -497,10 +500,11 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load Variables from Checkpint Dir
|
|
|
|
|
Load checkpoint from directory by executor,
|
|
|
|
|
it will find lastest checkpoint file and load it auto.
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
:param executor
|
|
|
|
|
:param dirname
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -517,14 +521,16 @@ def load_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
executor,
|
|
|
|
|
dirname=cur_dir,
|
|
|
|
|
main_program=main_program,
|
|
|
|
|
predicate=is_checkpoint_var,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_checkpoint_var(var):
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
VarType will fliter out FEED_MINIBATCH FETCH_LIST RAW
|
|
|
|
|
VarName will fliter out Gradient
|
|
|
|
|
checkpoint will not save or load all the variables.
|
|
|
|
|
var type is FEED_MINIBATCH/FETCH_LIST/RAW and var name is end with @GRAD are discarded.
|
|
|
|
|
|
|
|
|
|
:param var
|
|
|
|
|
"""
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
@ -545,9 +551,6 @@ def _interval_secs_exceed(dirname, save_interval_secs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
"""
|
|
|
|
|
retain checkpoint nums with max_num_checkpoints
|
|
|
|
|
"""
|
|
|
|
|
dirs = os.listdir(dirname)
|
|
|
|
|
serials = []
|
|
|
|
|
for serial in dirs:
|
|
|
|
@ -568,7 +571,9 @@ def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
|
|
|
|
|
def _write_success(dirname):
|
|
|
|
|
"""
|
|
|
|
|
write _SUCCESS to checkpoint dir
|
|
|
|
|
write an empty _SUCCESS file to checkpoint dir, indicate this checkpoint is correct.
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
"""
|
|
|
|
|
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
|
|
|
|
|
with open(success_file, 'a'):
|
|
|
|
@ -577,7 +582,9 @@ def _write_success(dirname):
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
get the biggest number in checkpoint_dir, which has _SUCCESS
|
|
|
|
|
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
"""
|
|
|
|
|
if not checkpoint_dir.strip():
|
|
|
|
|
return -1
|
|
|
|
|