@@ -73,7 +73,7 @@ class BeginStepEvent(object):
         self.step = step_id
         self.fetch_metrics = True
         """
         If fetch_metrics is true, the metrics will be fetched at the
         EndStepEvent. Default is True.
         """
@@ -560,6 +560,9 @@ class Trainer(object):
                 if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                         and step_id % self.checkpoint_cfg.step_interval == 0:
+                    print("_save_checkpoint ...")
+
                     exe = executor.Executor(self.place)
                     save_checkpoint(
                         executor=exe,
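For context, a minimal sketch of how these intervals gate saving. Only the `checkpoint_cfg` fields come from the hunk above; the config holder and the loop around it are assumptions for illustration:

.. code-block:: python

    class CheckpointConfig(object):
        # Assumed holder for the two interval fields used above.
        def __init__(self, epoch_interval=1, step_interval=10):
            self.epoch_interval = epoch_interval
            self.step_interval = step_interval

    cfg = CheckpointConfig(epoch_interval=1, step_interval=10)
    for epoch_id in range(2):
        for step_id in range(25):
            # With these settings, a checkpoint is written on steps 0, 10
            # and 20 of every epoch.
            if epoch_id % cfg.epoch_interval == 0 \
                    and step_id % cfg.step_interval == 0:
                print("would save at epoch %d, step %d" % (epoch_id, step_id))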
@@ -604,7 +607,7 @@ class Trainer(object):
                 self.checkpoint_cfg.epoch_id = int(trainer_args_ret[0])
                 self.checkpoint_cfg.step_id = int(trainer_args_ret[1])

             # Pserver Load
             else:
                 # load slice_vars
                 if self.slice_vars is not None and len(self.slice_vars) != 0:
@@ -661,22 +664,22 @@ CHECKPOINT_SEPARATOR = "_"


 def save_checkpoint(executor,
                     checkpoint_dir,
-                    trainer_id,
-                    main_program,
-                    trainer_args=None,
-                    max_num_checkpoints=3,
+                    main_program=None,
+                    trainer_id=0,
+                    save_trainer_args=None,
+                    save_lookup_table=None,
-                    pserver_endpoints=None):
+                    pserver_endpoints=None,
+                    max_num_checkpoints=3):
     """
     This function filters out all checkpoint variables from the given
     main_program and then saves these variables to the `checkpoint_dir`
     directory.

     In the training process, we generally save a checkpoint in each
     iteration. So there might be a lot of checkpoints in the
     `checkpoint_dir`. To avoid them taking too much disk space,
     `max_num_checkpoints` is introduced to limit the total number of
     checkpoints. If the number of existing checkpoints is greater than
     `max_num_checkpoints`, the oldest ones will be deleted.

     A variable is a checkpoint variable and will be saved if it meets
@@ -688,21 +691,21 @@ def save_checkpoint(executor,

     Args:
         executor(Executor): The executor to run for saving the checkpoint.
         checkpoint_dir(str): The folder where to save checkpoints.
         trainer_id(int): the current trainer id; if the id is equal to 0,
             the trainer is the chief.
         save_trainer_args(dict|None): current training arguments, such as
             'epoch_id' and 'step_id'.
             Default: None
         main_program(Program): The program whose checkpoint variables will
             be saved.
         max_num_checkpoints(int): The maximum number of existing
             checkpoints to keep.
             Default: 3
         save_lookup_table(string|None): the lookup table name. When using a
             distributed lookup table, the name can be obtained from
             DistributeTranspiler.table_name.
         pserver_endpoints(list|None): the parameter server ip:port list.
             When using a distributed lookup table, the endpoints can be
             obtained from the distribute arguments.

     Returns:
@@ -735,21 +738,18 @@ def save_checkpoint(executor,
     if checkpoint_dir is None:
         raise ValueError("'checkpoint_dir' should not be None")

-    if main_program is None:
-        raise ValueError('main_program should not be None.')
-
-    if trainer_args:
-        assert isinstance(trainer_args, dict)
-
-    is_chief = trainer_id == 0
-
     _make_chekcpoint_dirs(checkpoint_dir)
     serial = _get_latest_checkpoint_serial(checkpoint_dir) + 1
     cur_dir = _get_serial_dir(checkpoint_dir, serial, True)

-    _save_trainer_args(cur_dir, trainer_id, trainer_args)
+    is_chief = trainer_id == 0
+
+    if save_trainer_args is not None:
+        _save_trainer_args(cur_dir, trainer_id, save_trainer_args)

     if is_chief:
+        if main_program is None:
+            raise ValueError('main_program should not be None.')
         _save_persistable_vars(executor, cur_dir, main_program)

     if is_chief and save_lookup_table and pserver_endpoints:
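A usage sketch of the reworked signature; the paths and values are illustrative, and only the argument names come from the hunks above:

.. code-block:: python

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    # Every trainer may save its own args; only the chief (trainer_id == 0)
    # saves the model variables, as the body above shows.
    save_checkpoint(
        executor=exe,
        checkpoint_dir="./checkpoints",
        main_program=fluid.default_main_program(),
        trainer_id=0,
        save_trainer_args={"epoch_id": 3, "step_id": 120},
        save_lookup_table=None,
        pserver_endpoints=None,
        max_num_checkpoints=3)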
@@ -764,7 +764,7 @@ def load_checkpoint(executor,
                     main_program=None,
                     role_id=0,
                     is_trainer=True,
-                    load_models=True,
+                    load_models=False,
                     load_trainer_args=None,
                     load_slice_up_vars=None,
                     load_lookup_table=None):
@@ -774,8 +774,8 @@ def load_checkpoint(executor,
     `checkpoint_dir` directory.

     In the training process, we generally save a checkpoint in each
     iteration, so there may be more than one checkpoint in the
     `checkpoint_dir` (each checkpoint has its own sub folder); use
     `serial` to specify which checkpoint you would like to load.
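To make the `serial` wording concrete, a sketch of how a serial maps to a sub folder. Only `CHECKPOINT_SEPARATOR = "_"` appears in this diff; the `checkpoint` prefix and the helper name are assumptions:

.. code-block:: python

    import os

    CHECKPOINT_PREFIX = "checkpoint"  # assumed prefix
    CHECKPOINT_SEPARATOR = "_"

    def serial_dir(checkpoint_dir, serial):
        # e.g. serial_dir("./checkpoints", 2) -> "./checkpoints/checkpoint_2"
        name = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
        return os.path.join(checkpoint_dir, name)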
@@ -827,6 +827,10 @@ def load_checkpoint(executor,
         _load_persistable_vars(executor, checkpoint_dir, main_program, True)
         return

     if load_trainer_args:
+        print("checkpoint_dir: {}, role_id: {}, load_trainer_args: {}".
+              format(checkpoint_dir, role_id, load_trainer_args))
+
         trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
                                               load_trainer_args)
+        return trainer_args_ret
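A sketch of the trainer-side restore this hunk enables. `exe` is an existing Executor, the `checkpoint_dir` parameter between `executor` and `main_program` is assumed (this diff does not show it), and the int conversion mirrors the Trainer hunk at -604/+607:

.. code-block:: python

    trainer_args_ret = load_checkpoint(
        executor=exe,
        checkpoint_dir="./checkpoints",  # illustrative path
        main_program=None,
        role_id=0,
        is_trainer=True,
        load_models=False,
        load_trainer_args=["epoch_id", "step_id"])
    # The requested args come back as strings, in request order.
    epoch_id = int(trainer_args_ret[0])
    step_id = int(trainer_args_ret[1])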
@@ -842,9 +846,9 @@ def load_checkpoint(executor,


 def clean_checkpoint(checkpoint_dir, delete_dir=False):
     """
     clean the checkpoint dir; when the train exits normally, the trainer
     will call clean_checkpoint to delete the checkpoint directory saved before.
     delete_dir only works when the directory is empty; otherwise, an OSError
     is raised.

     : param checkpoint_dir
     : param delete_dir
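A short usage sketch; the path is illustrative:

.. code-block:: python

    # After a normal exit: remove the saved checkpoints, and with
    # delete_dir=True also remove the (now empty) directory itself;
    # a non-empty directory raises OSError, as documented above.
    clean_checkpoint(checkpoint_dir="./checkpoints", delete_dir=True)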
@@ -954,7 +958,7 @@ def _load_slice_up_vars(executor, dirname, slice_vars):


 def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
     """
     The parameter server will load the lookup table's local file into a
     SelectedRows variable.

     Args:
@@ -1005,7 +1009,7 @@ def _load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):


 def _save_persistable_vars(executor, dirname, program):
     """
     This function filters out all checkpoint variables from the given
     program and then saves these variables to a sub-folder '__model__' of
     the given directory.

     A variable is a checkpoint variable if it meets all following
@@ -1034,7 +1038,7 @@ def _save_persistable_vars(executor, dirname, program):

             # In this example, the `_save_persistable_vars` function
             # will first filter out all checkpoint variables in the default
             # main program, and then save these variables to the folder
             # "./my_paddle_model/__model__".
     """
     cur_dir = _get_model_dir(dirname)
@@ -1053,7 +1057,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
     """
     This function will send a checkpoint notify message from Trainer 0
     to all the pservers.
     The checkpoint notify message contains the lookup table name and
     the absolute path on the pserver where the lookup_table will be saved.

     Args:
@@ -1061,13 +1065,13 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
         dirname(str): The folder where to save checkpoints.
         lookup_table(string): the lookup table name. When using a
             distributed lookup table, the name can be obtained from
             DistributeTranspiler.table_name.
         ps_endpoint_list(list): the parameter server ip:port list.
             When using a distributed lookup table, the endpoints can be
             obtained from the distribute arguments.
     Returns:
         None

     Examples:
         .. code-block:: python
@@ -1078,7 +1082,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
             ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]

             _save_pserver_vars_by_notify(executor=exe,
                     dirname=param_path, lookup_table=table_name,
                     ps_endpoint_list=ps_endpoints)
     """
     cur_dir = _get_lookuptable_dir(dirname)
|
|
|
|
|
|
|
|
|
|
def _load_trainer_args(checkpoint_dir, trainer_id, trainer_args):
|
|
|
|
|
"""
|
|
|
|
|
trainer will load some args from it's independent directory,
|
|
|
|
|
trainer will load some args from it's independent directory,
|
|
|
|
|
such as epoch_id and step_id.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
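A sketch of the per-trainer args round trip this helper implements. The file layout (one text file per arg, named after the arg, under the trainer's own sub directory) is assumed, not shown in this diff:

.. code-block:: python

    import os

    def _sketch_load_trainer_args(trainer_dir, arg_names):
        # Assumed layout: each arg was saved as str(value) into
        # <trainer_dir>/<arg_name>, e.g. <trainer_dir>/epoch_id.
        values = []
        for name in arg_names:
            with open(os.path.join(trainer_dir, name)) as f:
                values.append(f.read().strip())
        return values  # strings, in request order; callers convert types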
@@ -1264,8 +1268,6 @@ def _get_latest_checkpoint_serial(checkpoint_dir):

     : param checkpoint_dir
     """
-    if not checkpoint_dir:
-        return -1

     def has_success(checkpoint_dir, cur_dir):
         """
@@ -1273,8 +1275,8 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
         """

         serial = _get_dir_serial(cur_dir)
-        if serial == -1 or not os.path.isdir(
-                os.path.join(checkpoint_dir, cur_dir)):
+        if serial == -1 or \
+                not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
             return -1

         success_path = os.path.join(
@@ -1283,10 +1285,11 @@ def _get_latest_checkpoint_serial(checkpoint_dir):
         if os.path.isfile(success_path):
             return serial

-    if not os.path.isdir(checkpoint_dir):
-        return -1
+    current_dir = -1
+
+    if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
+        return current_dir

     dirs = os.listdir(checkpoint_dir)
     for cur_dir in dirs:
         success_num = has_success(checkpoint_dir, cur_dir)
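Pulled together, the scan after this change reads roughly as below. `_get_dir_serial` and the success-marker file name are assumptions inferred from context; the control flow matches the hunks above:

.. code-block:: python

    import os

    SUCCESS_MARK_FILENAME = "_SUCCESS"  # assumed marker name

    def _get_dir_serial(dirname):
        # Assumed: "checkpoint_2" -> 2; unparsable names -> -1.
        try:
            return int(dirname.split("_")[-1])
        except ValueError:
            return -1

    def latest_checkpoint_serial(checkpoint_dir):
        current_dir = -1
        if not checkpoint_dir or not os.path.isdir(checkpoint_dir):
            return current_dir
        for cur_dir in os.listdir(checkpoint_dir):
            serial = _get_dir_serial(cur_dir)
            if serial == -1 or \
                    not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
                continue
            success = os.path.join(checkpoint_dir, cur_dir,
                                   SUCCESS_MARK_FILENAME)
            # Only checkpoints with a success marker count as complete.
            if os.path.isfile(success):
                current_dir = max(current_dir, serial)
        return current_dir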