|
|
|
@ -472,8 +472,7 @@ def save_checkpoint(executor,
|
|
|
|
|
main_program=None,
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
lookup_table=None,
|
|
|
|
|
ps_endpoint_list=None
|
|
|
|
|
):
|
|
|
|
|
ps_endpoint_list=None):
|
|
|
|
|
"""
|
|
|
|
|
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
|
|
|
|
|
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
|
|
|
|
@ -495,14 +494,18 @@ def save_checkpoint(executor,
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
os.makedirs(checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
is_chief = trainer_id == 0
|
|
|
|
|
|
|
|
|
|
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
|
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial)
|
|
|
|
|
|
|
|
|
|
save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
|
|
|
|
|
|
if trainer_id == 0:
|
|
|
|
|
if is_chief:
|
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
save_pserver_vars_by_notify(executor, cur_dir, lookup_table, ps_endpoint_list)
|
|
|
|
|
if is_chief and lookup_table and ps_endpoint_list:
|
|
|
|
|
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
|
|
|
|
|
ps_endpoint_list)
|
|
|
|
|
|
|
|
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
@ -618,7 +621,8 @@ def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
|
|
|
|
|
def save_pserver_vars_by_notify(executor, dirname, lookup_table,
|
|
|
|
|
ps_endpoint_list):
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
cur_dir = _get_lookuptable_dir(dirname)
|
|
|
|
@ -802,4 +806,3 @@ def get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return current_dir
|
|
|
|
|
|
|
|
|
|