@@ -134,8 +134,6 @@ class CheckpointConfig(object):
         self.epoch_id = 0
         self.step_id = 0
         self.load_serial = None
-        self.pserver_id = None
-        self.lookup_table_name = None


 def check_and_get_place(place):
@@ -351,11 +349,9 @@ class Trainer(object):
             t.transpile(
                 self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
             if training_role == "PSERVER":
-                if self.checkpoint_cfg:
-                    pserver_id = eplist.index(current_endpoint)
-                    self.checkpoint_cfg.pserver_id = pserver_id
-                    if t.has_distributed_lookup_table:
-                        self.checkpoint_cfg.lookup_table_name = t.table_name
+                self.pserver_id = eplist.index(current_endpoint)
+                self.pserver_endpoints = pserver_endpoints
+                self.lookup_table_name = t.table_name if t.has_distributed_lookup_table else None

                 self.train_program = t.get_pserver_program(current_endpoint)
                 self.startup_program = t.get_startup_program(current_endpoint,
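
Note on the hunk above: the pserver's identity and the lookup table name now live
directly on the Trainer instead of being stashed on CheckpointConfig, so the
checkpoint config stays pure configuration. A minimal sketch of how a pserver
derives its role id (the endpoint values here are illustrative):

    # eplist is the ordered list of pserver endpoints; the index of this
    # server's own ip:port doubles as its stable role id.
    pserver_endpoints = "127.0.0.1:6170,127.0.0.1:6171"
    eplist = pserver_endpoints.split(",")
    current_endpoint = "127.0.0.1:6171"
    pserver_id = eplist.index(current_endpoint)  # -> 1
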
@@ -417,6 +413,11 @@ class Trainer(object):
     def save_params(self, param_path):
         """
         Save all parameters into :code:`param_path`.

-        Only No.0 trainer will save dense params.
+        In standalone PaddlePaddle, the only existing trainer saves the dense params.
+        In distributed PaddlePaddle, the No.0 trainer saves the dense params; if there
+        is a lookup table to save, the No.0 trainer also broadcasts a notification to
+        all Parameter Servers so that each saves its part independently.

         Args:
             param_path(str): The path to save parameters.
@@ -424,9 +425,19 @@ class Trainer(object):
         Returns:
             None
         """
+        if self.trainer_id != 0:
+            return
+
         with self._prog_and_scope_guard():
+            # save params on trainer
             exe = executor.Executor(self.place)
             io.save_persistables(exe, dirname=param_path)
+
+            # save params on pserver
+            if self.lookup_table_name:
+                _save_pserver_vars_by_notify(exe, param_path,
+                                             self.lookup_table_name,
+                                             self.pserver_endpoints)

     @contextlib.contextmanager
     def _prog_and_scope_guard(self):
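
With the early return above, save_params is a no-op on every trainer except No.0,
so it is safe to call unconditionally from all trainer processes. A hedged usage
sketch (the Trainer constructor arguments are illustrative, not a full setup):

    import paddle.fluid as fluid

    trainer = fluid.Trainer(
        train_func=train_program_func,
        optimizer_func=optimizer_func,
        place=fluid.CPUPlace())
    # ... train ...
    trainer.save_params("/tmp/model_params")  # returns immediately if trainer_id != 0
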
@@ -560,15 +571,16 @@ class Trainer(object):
         if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
                 and step_id % self.checkpoint_cfg.step_interval == 0:

             print("_save_checkpoint ...")

             exe = executor.Executor(self.place)
             save_checkpoint(
                 executor=exe,
                 checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
-                trainer_id=self.trainer_id,
-                trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
                 main_program=self.train_program,
+                trainer_id=self.trainer_id,
+                save_trainer_args=self._get_checkpoint_save_args(epoch_id,
+                                                                 step_id),
+                save_lookup_table=self.lookup_table_name,
+                pserver_endpoints=self.pserver_endpoints,
                 max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

     def _load_checkpoint(self):
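
The checkpoint gate fires only when both interval conditions hold at once: with
epoch_interval=1 and step_interval=10, a checkpoint is written every 10 steps of
every epoch. The same predicate in isolation (a sketch, not library code):

    def should_checkpoint(epoch_id, step_id, epoch_interval=1, step_interval=10):
        # both modulo conditions must hold for a checkpoint to be written
        return epoch_id % epoch_interval == 0 and step_id % step_interval == 0

    assert should_checkpoint(epoch_id=3, step_id=20)      # 20 % 10 == 0
    assert not should_checkpoint(epoch_id=3, step_id=25)  # 25 % 10 != 0
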
@@ -579,7 +591,7 @@ class Trainer(object):
                 self.checkpoint_cfg.load_serial)

             # Trainer Load
-            if self.checkpoint_cfg.pserver_id is None:
+            if self.pserver_id is None:
                 # load model
                 load_checkpoint(
                     executor=exe,
@@ -608,15 +620,25 @@ class Trainer(object):

             # Pserver Load
             else:
+                # load model
+                load_checkpoint(
+                    executor=exe,
+                    checkpoint_dir=checkpoint_dir,
+                    main_program=self.startup_program,
+                    role_id=self.pserver_id,
+                    is_trainer=False,
+                    load_models=True,
+                    load_lookup_table=self.lookup_table_name)
+
                 # load lookup table
-                if self.checkpoint_cfg.lookup_table_name:
+                if self.lookup_table_name:
                     load_checkpoint(
                         executor=exe,
                         checkpoint_dir=checkpoint_dir,
                         main_program=self.startup_program,
-                        role_id=self.checkpoint_cfg.pserver_id,
+                        role_id=self.pserver_id,
                         is_trainer=False,
-                        load_lookup_table=self.checkpoint_cfg.lookup_table_name)
+                        load_lookup_table=self.lookup_table_name)


 def build_feed_var_list(program, feed_order):
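
Restore now branches purely on self.pserver_id, which presumably stays None on
trainer processes (an assumption: Trainer.__init__ is expected to default it to
None elsewhere in this patch). Trainers take the trainer path; pservers first
reload their model shard and then, if a distributed lookup table exists, reload
that table separately. A runnable sketch of the role discrimination:

    def checkpoint_role(pserver_id):
        # trainers never set pserver_id, so None identifies the trainer path
        return "trainer" if pserver_id is None else "pserver-%d" % pserver_id

    assert checkpoint_role(None) == "trainer"
    assert checkpoint_role(1) == "pserver-1"
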
@@ -813,13 +835,21 @@ def load_checkpoint(executor,
     if is_trainer:
+        if load_models:
+            _load_persistable_vars(executor, checkpoint_dir, main_program, True)
+            return
+
         if load_trainer_args:
             trainer_args_ret = _load_trainer_args(checkpoint_dir, role_id,
                                                   load_trainer_args)
             return trainer_args_ret
     # pserver load
     else:
+        if load_models:
+            if load_lookup_table:
+                _load_persistable_vars(executor, checkpoint_dir, main_program,
+                                       True, [load_lookup_table])
+            else:
+                _load_persistable_vars(executor, checkpoint_dir, main_program,
+                                       True)
+
         if load_lookup_table:
             _load_lookup_table_vars(executor, checkpoint_dir, main_program,
                                     role_id, load_lookup_table)
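
load_checkpoint is now a single entry point multiplexed by is_trainer, load_models,
load_trainer_args and load_lookup_table. A hedged summary of the main call shapes
(directory, program and id values are illustrative; the trainer-args list mirrors
what _get_checkpoint_save_args stores, assumed here to be epoch_id/step_id):

    # trainer: restore model parameters
    load_checkpoint(executor=exe, checkpoint_dir=ckpt_dir, main_program=prog,
                    role_id=trainer_id, is_trainer=True, load_models=True)

    # trainer: recover its own saved args; returns the stored values
    args = load_checkpoint(executor=exe, checkpoint_dir=ckpt_dir, main_program=prog,
                           role_id=trainer_id, is_trainer=True,
                           load_trainer_args=["epoch_id", "step_id"])

    # pserver: restore its shard while excluding the lookup table variables
    load_checkpoint(executor=exe, checkpoint_dir=ckpt_dir, main_program=prog,
                    role_id=pserver_id, is_trainer=False, load_models=True,
                    load_lookup_table=table_name)
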
@@ -843,7 +873,11 @@ def clean_checkpoint(checkpoint_dir, delete_dir=False):
         os.rmdir(checkpoint_dir)


-def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
+def _load_persistable_vars(executor,
+                           dirname,
+                           program,
+                           has_model_dir=False,
+                           except_vars=None):
     """
     This function filters out all checkpoint variables from the given
     program and then tries to load these variables from the given directory.
@@ -888,7 +922,7 @@ def _load_persistable_vars(executor, dirname, program, has_model_dir=False):
         executor,
         dirname=dirname,
         main_program=program,
-        predicate=_is_checkpoint_var,
+        predicate=_is_checkpoint_var(except_vars),
         filename=None)
@@ -983,13 +1017,13 @@ def _save_persistable_vars(executor, dirname, program):
         dirname=cur_dir,
         main_program=program,
         vars=None,
-        predicate=_is_checkpoint_var,
+        predicate=_is_checkpoint_var(),
         filename=None)
     _write_success(cur_dir)


 def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
-                                 ps_endpoint_list):
+                                 pserver_endpoints):
     """
     This function will send a checkpoint notify message from Trainer 0
     to all the pservers.
@@ -1002,8 +1036,8 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
         lookup_table(string): the lookup table name, when use distribute
             lookup table, we can get lookup table name by DistributeTranspiler.
             table_name
-        ps_endpoint_list(list): the parameter server ip:port list.
-            when use distribute lookup table, we can get ps_endpoint_list by
+        pserver_endpoints(str): the parameter server ip:port list as one
+            comma-separated string; with a distributed lookup table it comes from the
             distribute arguments.
     Return:
         None
@@ -1027,7 +1061,7 @@ def _save_pserver_vars_by_notify(executor, dirname, lookup_table,
     checkpoint_notify_block = checkpoint_notify_program.global_block()

     attrs = {}
-    attrs['epmap'] = ps_endpoint_list
+    attrs['epmap'] = pserver_endpoints.split(",")
     attrs['dir'] = cur_dir
     attrs['lookup_table'] = lookup_table
|
|
|
|
|
return ret_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checkpoint_var(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
|
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
|
|
|
|
|
def _is_checkpoint_var(except_vars=None):
|
|
|
|
|
except_vars = [] if except_vars is None else except_vars
|
|
|
|
|
|
|
|
|
|
: param var(Variable)
|
|
|
|
|
"""
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
|
return False
|
|
|
|
|
# @GRAD are named for gradient variables, checkpoint will not save it.
|
|
|
|
|
if "@GRAD" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
# .trainer_ are named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".trainer_" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# .block is named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".block" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return var.persistable
|
|
|
|
|
def _except_vars(var):
|
|
|
|
|
"""
|
|
|
|
|
the checkpoint will not save or load all the variables.
|
|
|
|
|
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
|
|
|
|
|
|
|
|
|
|
: param var(Variable)
|
|
|
|
|
"""
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
|
|
|
|
|
var.desc.type() == core.VarDesc.VarType.RAW:
|
|
|
|
|
return False
|
|
|
|
|
# @GRAD are named for gradient variables, checkpoint will not save it.
|
|
|
|
|
if "@GRAD" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
# .trainer_ are named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".trainer_" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# .block is named for distribute train variables, checkpoint will not save it.
|
|
|
|
|
if ".block" in var.name:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if var in except_vars:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
return var.persistable
|
|
|
|
|
|
|
|
|
|
return _except_vars
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_chekcpoint_dirs(dirs):
|
|
|
|
|
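
Turning _is_checkpoint_var into a factory lets each call site bake an exclusion
list into the predicate before handing it to io.save_vars/io.load_vars, which
expect a one-argument callable. Note that except_vars holds variable names, to
match the [load_lookup_table] list passed in by _load_persistable_vars. A usage
sketch (the table name is illustrative):

    # build a predicate that skips the distributed lookup table
    keep = _is_checkpoint_var(except_vars=["emb_lookup_table"])
    io.load_vars(
        executor,
        dirname=checkpoint_dir,
        main_program=program,
        predicate=keep,  # keep(var) -> bool, as io.load_vars expects
        filename=None)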