|
|
|
@ -23,7 +23,7 @@ from . import core
|
|
|
|
|
__all__ = [
|
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
|
|
'load_persistables', 'save_inference_model', 'load_inference_model',
|
|
|
|
|
'get_inference_program', 'save_checkpoint', 'restore_checkpoint'
|
|
|
|
|
'get_inference_program', 'save_checkpoint', 'load_checkpoint'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -466,7 +466,7 @@ def save_checkpoint(executor,
|
|
|
|
|
Save Variables to Checkpoint Directory
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
:param keep_max
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param save_secs
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
@ -495,7 +495,7 @@ def save_checkpoint(executor,
|
|
|
|
|
_lru_delete(dirname, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def restore_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
def load_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load Variables from Checkpint Dir
|
|
|
|
|
|
|
|
|
@ -544,9 +544,9 @@ def _interval_secs_exceed(dirname, save_interval_secs):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lru_delete(dirname, keep_max=3):
|
|
|
|
|
def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
"""
|
|
|
|
|
retain checkpoint nums with keep_max
|
|
|
|
|
retain checkpoint nums with max_num_checkpoints
|
|
|
|
|
"""
|
|
|
|
|
dirs = os.listdir(dirname)
|
|
|
|
|
serials = []
|
|
|
|
@ -556,11 +556,11 @@ def _lru_delete(dirname, keep_max=3):
|
|
|
|
|
except ValueError:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if len(serials) <= keep_max:
|
|
|
|
|
if len(serials) <= max_num_checkpoints:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serials.sort(reverse=True)
|
|
|
|
|
serials = serials[keep_max:]
|
|
|
|
|
serials = serials[max_num_checkpoints:]
|
|
|
|
|
for serial in serials:
|
|
|
|
|
cur_dir = os.path.join(dirname, str(serial))
|
|
|
|
|
shutil.rmtree(cur_dir)
|
|
|
|
|