|
|
|
@ -455,10 +455,12 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SUCCESS_MARK_FILENAME = "_SUCCESS"
|
|
|
|
|
CHECKPOINT_PREFIX = "checkpoint"
|
|
|
|
|
CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
dirname=None,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
save_interval_secs=600,
|
|
|
|
|
main_program=None):
|
|
|
|
@ -466,26 +468,27 @@ def save_checkpoint(executor,
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
|
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
|
|
|
|
|
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
|
|
|
|
|
The interval time between two save_checkpoint must great than or equal to save_interval_secs.
|
|
|
|
|
The interval between two saved checkpoints must greater than save_interval_secs.
|
|
|
|
|
|
|
|
|
|
:param dirname
|
|
|
|
|
:param executor
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param max_num_checkpoints
|
|
|
|
|
:param save_secs
|
|
|
|
|
:param save_interval_secs
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
if dirname is None:
|
|
|
|
|
dirname = os.getcwd()
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(dirname):
|
|
|
|
|
os.makedirs(dirname)
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(dirname)
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial >= 0 and not _interval_secs_exceed(
|
|
|
|
|
os.path.join(dirname, str(serial)), save_interval_secs):
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serial = serial + 1
|
|
|
|
|
cur_dir = os.path.join(dirname, str(serial))
|
|
|
|
|
serial += 1
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
save_vars(
|
|
|
|
|
executor,
|
|
|
|
@ -495,27 +498,28 @@ def save_checkpoint(executor,
|
|
|
|
|
predicate=_is_checkpoint_var,
|
|
|
|
|
filename=None)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
_lru_delete(dirname, max_num_checkpoints)
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Load checkpoint from a directory by executor,
|
|
|
|
|
it will find latest checkpoint file and load it auto.
|
|
|
|
|
it will find the most recent saved checkpoint file and load it auto.
|
|
|
|
|
|
|
|
|
|
:param executor
|
|
|
|
|
:param dirname
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
|
:param main_program
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if dirname is None:
|
|
|
|
|
dirname = os.getcwd()
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
checkpoint_dir = os.getcwd()
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(dirname)
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return
|
|
|
|
|
cur_dir = os.path.join(dirname, str(serial))
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
load_vars(
|
|
|
|
|
executor,
|
|
|
|
@ -525,6 +529,11 @@ def load_checkpoint(executor, dirname=None, main_program=None):
|
|
|
|
|
filename=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(serial, checkpoint_dir):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
return os.path.join(checkpoint_dir, serial_folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
@ -577,7 +586,8 @@ def _write_success(dirname):
|
|
|
|
|
"""
|
|
|
|
|
success_file = os.path.join(dirname, SUCCESS_MARK_FILENAME)
|
|
|
|
|
with open(success_file, 'a'):
|
|
|
|
|
pass
|
|
|
|
|
now = time.ctime()
|
|
|
|
|
success_file.write(now)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
@ -593,18 +603,20 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
is _SUCCESS in this dir
|
|
|
|
|
"""
|
|
|
|
|
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
|
return -1
|
|
|
|
|
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
int(cur_dir)
|
|
|
|
|
int(serial)
|
|
|
|
|
except ValueError:
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(checkpoint_dir, cur_dir,
|
|
|
|
|
SUCCESS_MARK_FILENAME)
|
|
|
|
|
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return int(cur_dir)
|
|
|
|
|
return int(serial)
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
return -1
|
|
|
|
|