|
|
|
@ -471,7 +471,10 @@ def save_checkpoint(executor,
|
|
|
|
|
trainer_id,
|
|
|
|
|
trainer_args=None,
|
|
|
|
|
main_program=None,
|
|
|
|
|
max_num_checkpoints=3):
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
lookup_table=None,
|
|
|
|
|
ps_endpoint_list=None
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Save Checkpoint will save persistable LoDTensor variables from main_program in the checkpoint directory,
|
|
|
|
|
the directory is named by serial number from 0 to (n - 1); save_checkpoint uses an LRU strategy
|
|
|
|
@ -500,7 +503,7 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
if trainer_id == 0:
|
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
save_pserver_vars_by_notify(executor, cur_dir, "")
|
|
|
|
|
save_pserver_vars_by_notify(executor, cur_dir, ps_endpoint_list, lookup_table)
|
|
|
|
|
|
|
|
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
@ -600,7 +603,7 @@ def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_pserver_vars_by_notify(executor, dirname, epmap):
|
|
|
|
|
def save_pserver_vars_by_notify(executor, dirname, lookup_table, ps_endpoint_list):
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
cur_dir = _get_lookuptable_dir(dirname)
|
|
|
|
@ -609,11 +612,12 @@ def save_pserver_vars_by_notify(executor, dirname, epmap):
|
|
|
|
|
checkpoint_notify_block = checkpoint_notify_program.global_block()
|
|
|
|
|
|
|
|
|
|
attrs = {}
|
|
|
|
|
attrs['epmap'] = None
|
|
|
|
|
attrs['epmap'] = ps_endpoint_list
|
|
|
|
|
attrs['dir'] = cur_dir
|
|
|
|
|
attrs['lookup_table'] = lookup_table
|
|
|
|
|
|
|
|
|
|
checkpoint_notify_block.append_op(
|
|
|
|
|
type='checkpoint_notify', inputs={}, output={}, attrs=attrs)
|
|
|
|
|
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
|
|
|
|
|
executor.run(checkpoint_notify_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -783,3 +787,4 @@ def get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
|
if success_num > current_dir:
|
|
|
|
|
current_dir = success_num
|
|
|
|
|
return current_dir
|
|
|
|
|
|
|
|
|
|