|
|
|
@ -491,7 +491,6 @@ CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
|
checkpoint_dir=None,
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
save_interval_secs=600,
|
|
|
|
|
main_program=None):
|
|
|
|
|
"""
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
@ -511,15 +510,10 @@ def save_checkpoint(executor,
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir)
|
|
|
|
|
if serial >= 0 and not _interval_secs_exceed(
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serial += 1
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
serial = _get_lastest_checkpoint_dir(checkpoint_dir) + 1
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
|
load_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
_lru_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
@ -542,7 +536,7 @@ def load_checkpoint(executor, checkpoint_dir=None, main_program=None):
|
|
|
|
|
if serial < 0:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
cur_dir = _get_serial_dir(serial, checkpoint_dir)
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
load_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -559,11 +553,6 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
|
os.rmdir(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(serial, checkpoint_dir):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
return os.path.join(checkpoint_dir, serial_folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
@ -582,29 +571,37 @@ def _is_checkpoint_var(var):
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _interval_secs_exceed(dirname, save_interval_secs):
|
|
|
|
|
dir_time = os.path.getmtime(dirname)
|
|
|
|
|
if save_interval_secs > (time.time() - dir_time):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
def _get_dir_serial(dirname):
|
|
|
|
|
_, serial = dirname.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
|
serial_num = -1
|
|
|
|
|
try:
|
|
|
|
|
serial_num = int(serial)
|
|
|
|
|
except ValueError:
|
|
|
|
|
serial_num = -1
|
|
|
|
|
return serial_num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_serial_dir(dirname, serial):
|
|
|
|
|
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
|
|
|
|
|
return os.path.join(dirname, serial_folder)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _lru_delete(dirname, max_num_checkpoints=3):
|
|
|
|
|
dirs = os.listdir(dirname)
|
|
|
|
|
serials = []
|
|
|
|
|
serial_map = {}
|
|
|
|
|
for serial in dirs:
|
|
|
|
|
try:
|
|
|
|
|
serials.append(int(serial))
|
|
|
|
|
except ValueError:
|
|
|
|
|
continue
|
|
|
|
|
serial_num = _get_dir_serial(serial)
|
|
|
|
|
serial_map[serial_num] = serial
|
|
|
|
|
|
|
|
|
|
if len(serials) <= max_num_checkpoints:
|
|
|
|
|
if len(serial_map.keys()) <= max_num_checkpoints:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
serials = serial_map.keys()
|
|
|
|
|
serials.sort(reverse=True)
|
|
|
|
|
serials = serials[max_num_checkpoints:]
|
|
|
|
|
for serial in serials:
|
|
|
|
|
cur_dir = os.path.join(dirname, str(serial))
|
|
|
|
|
cur_dir = _get_serial_dir(dirname, serial)
|
|
|
|
|
shutil.rmtree(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -633,20 +630,18 @@ def _get_lastest_checkpoint_dir(checkpoint_dir):
|
|
|
|
|
"""
|
|
|
|
|
is _SUCCESS in this dir
|
|
|
|
|
"""
|
|
|
|
|
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
int(serial)
|
|
|
|
|
except ValueError:
|
|
|
|
|
serial = _get_dir_serial(cur_dir)
|
|
|
|
|
if serial == -1:
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(
|
|
|
|
|
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME)
|
|
|
|
|
_get_serial_dir(checkpoint_dir, serial), SUCCESS_MARK_FILENAME)
|
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
|
return int(serial)
|
|
|
|
|
return serial
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
return -1
|
|
|
|
|