|
|
@ -560,6 +560,9 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
|
|
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
|
|
|
|
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
|
|
|
|
and step_id % self.checkpoint_cfg.step_interval == 0:
|
|
|
|
and step_id % self.checkpoint_cfg.step_interval == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("_save_checkpoint ...")
|
|
|
|
|
|
|
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
save_checkpoint(
|
|
|
|
save_checkpoint(
|
|
|
|
executor=exe,
|
|
|
|
executor=exe,
|
|
|
@ -661,12 +664,12 @@ CHECKPOINT_SEPARATOR = "_"
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
def save_checkpoint(executor,
|
|
|
|
checkpoint_dir,
|
|
|
|
checkpoint_dir,
|
|
|
|
trainer_id,
|
|
|
|
main_program=None,
|
|
|
|
main_program,
|
|
|
|
trainer_id=0,
|
|
|
|
trainer_args=None,
|
|
|
|
save_trainer_args=None,
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
|
|
|
save_lookup_table=None,
|
|
|
|
save_lookup_table=None,
|
|
|
|
pserver_endpoints=None):
|
|
|
|
pserver_endpoints=None,
|
|
|
|
|
|
|
|
max_num_checkpoints=3):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function filters out all checkpoint variables from the give
|
|
|
|
This function filters out all checkpoint variables from the give
|
|
|
|
main_program and then saves these variables to the `checkpoint_dir`
|
|
|
|
main_program and then saves these variables to the `checkpoint_dir`
|
|
|
@ -735,21 +738,18 @@ def save_checkpoint(executor,
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
raise ValueError("'checkpoint_dir' should not be None")
|
|
|
|
|
|
|
|
|
|
|
|
if main_program is None:
|
|
|
|
|
|
|
|
raise ValueError('main_program should not be None.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if trainer_args:
|
|
|
|
|
|
|
|
assert isinstance(trainer_args, dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_chief = trainer_id == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_make_chekcpoint_dirs(checkpoint_dir)
|
|
|
|
_make_chekcpoint_dirs(checkpoint_dir)
|
|
|
|
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
|
|
|
|
serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
|
|
|
|
cur_dir = _get_serial_dir(checkpoint_dir, serial, True)
|
|
|
|
|
|
|
|
|
|
|
|
_save_trainer_args(cur_dir, trainer_id, trainer_args)
|
|
|
|
is_chief = trainer_id == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if save_trainer_args is not None:
|
|
|
|
|
|
|
|
_save_trainer_args(cur_dir, trainer_id, save_trainer_args)
|
|
|
|
|
|
|
|
|
|
|
|
if is_chief:
|
|
|
|
if is_chief:
|
|
|
|
|
|
|
|
if main_program is None:
|
|
|
|
|
|
|
|
raise ValueError('main_program should not be None.')
|
|
|
|
_save_persistable_vars(executor, cur_dir, main_program)
|
|
|
|
_save_persistable_vars(executor, cur_dir, main_program)
|
|
|
|
|
|
|
|
|
|
|
|
if is_chief and save_lookup_table and pserver_endpoints:
|
|
|
|
if is_chief and save_lookup_table and pserver_endpoints:
|
|
|
@ -764,7 +764,7 @@ def load_checkpoint(executor,
|
|
|
|
main_program=None,
|
|
|
|
main_program=None,
|
|
|
|
role_id=0,
|
|
|
|
role_id=0,
|
|
|
|
is_trainer=True,
|
|
|
|
is_trainer=True,
|
|
|
|
load_models=True,
|
|
|
|
load_models=False,
|
|
|
|
load_trainer_args=None,
|
|
|
|
load_trainer_args=None,
|
|
|
|
load_slice_up_vars=None,
|
|
|
|
load_slice_up_vars=None,
|
|
|
|
load_lookup_table=None):
|
|
|
|
load_lookup_table=None):
|
|
|
@ -827,6 +827,10 @@ def load_checkpoint(executor,
|
|
|
|
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
|
|
|
|
_load_persistable_vars(executor, checkpoint_dir, main_program, True)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
if load_trainer_args:
|
|
|
|
if load_trainer_args:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
|
|
|
|
|
|
|
|
format(checkpoint_dir, role_id, load_trainer_args))
|
|
|
|
|
|
|
|
|
|
|
|
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
|
|
|
|
trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
|
|
|
|
load_trainer_args)
|
|
|
|
load_trainer_args)
|
|
|
|
return trainer_args_ret
|
|
|
|
return trainer_args_ret
|
|
|
@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
|
|
|
|
|
|
|
|
: param checkpoint_dir
|
|
|
|
: param checkpoint_dir
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not checkpoint_dir:
|
|
|
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
|
|
def has_success(checkpoint_dir, cur_dir):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
serial = _get_dir_serial(cur_dir)
|
|
|
|
serial = _get_dir_serial(cur_dir)
|
|
|
|
if serial == -1 or not os.path.isdir(
|
|
|
|
if serial == -1 or \
|
|
|
|
os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
|
|
|
|
return -1
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
|
|
success_path = os.path.join(
|
|
|
|
success_path = os.path.join(
|
|
|
@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
if os.path.isfile(success_path):
|
|
|
|
return serial
|
|
|
|
return serial
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.isdir(checkpoint_dir):
|
|
|
|
|
|
|
|
return -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
current_dir = -1
|
|
|
|
current_dir = -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
|
|
|
|
|
|
|
|
return current_dir
|
|
|
|
|
|
|
|
|
|
|
|
dirs = os.listdir(checkpoint_dir)
|
|
|
|
dirs = os.listdir(checkpoint_dir)
|
|
|
|
for cur_dir in dirs:
|
|
|
|
for cur_dir in dirs:
|
|
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
|
|
success_num = has_success(checkpoint_dir, cur_dir)
|
|
|
|