|
|
|
@ -500,6 +500,7 @@ def save_checkpoint(executor,
|
|
|
|
|
|
|
|
|
|
|
|
if trainer_id == 0:
|
|
|
|
if trainer_id == 0:
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
save_persist_vars_without_grad(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
save_pserver_vars_by_notify(executor, cur_dir, "")
|
|
|
|
|
|
|
|
|
|
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
_scroll_delete(checkpoint_dir, max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
@ -530,7 +531,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
|
|
|
|
|
|
|
|
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
def clean_checkpoint(checkpoint_dir, delete_dir=False):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
|
|
clean the checkpoint dir, when the train exits normally,
|
|
|
|
|
|
|
|
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
delete_dir only works when the directory is empty, otherwise, OSError is raised.
|
|
|
|
|
|
|
|
|
|
|
|
:param checkpoint_dir
|
|
|
|
:param checkpoint_dir
|
|
|
|
@ -598,6 +600,23 @@ def save_persist_vars_without_grad(executor, dirname, program):
|
|
|
|
_write_success(cur_dir)
|
|
|
|
_write_success(cur_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_pserver_vars_by_notify(executor, dirname, epmap):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
cur_dir = _get_lookuptable_dir(dirname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint_notify_program = Program()
|
|
|
|
|
|
|
|
checkpoint_notify_block = checkpoint_notify_program.global_block()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
attrs = {}
|
|
|
|
|
|
|
|
attrs['epmap'] = None
|
|
|
|
|
|
|
|
attrs['dir'] = cur_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint_notify_block.append_op(
|
|
|
|
|
|
|
|
type='checkpointnotify', inputs={}, output={}, attrs=attrs)
|
|
|
|
|
|
|
|
executor.run(checkpoint_notify_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_trainer_args(dirname, trainer_id, trainer_args):
|
|
|
|
def save_trainer_args(dirname, trainer_id, trainer_args):
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
|
|
|
|
|
|
|
|
@ -680,6 +699,15 @@ def _get_model_dir(dirname):
|
|
|
|
return model_dir
|
|
|
|
return model_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_lookuptable_dir(dirname):
|
|
|
|
|
|
|
|
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(lookuptable_dir):
|
|
|
|
|
|
|
|
os.makedirs(lookuptable_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return lookuptable_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_trainer_dir(dirname, trainer_id):
|
|
|
|
def _get_trainer_dir(dirname, trainer_id):
|
|
|
|
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
|
|
|
|
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
|
|
|
|
trainer_dir = os.path.join(dirname, trainer_folder)
|
|
|
|
trainer_dir = os.path.join(dirname, trainer_folder)
|
|
|
|
|