|
|
|
@ -235,6 +235,9 @@ class Trainer(object):
|
|
|
|
|
# config for checkpoint
|
|
|
|
|
# only chief worker will save variables
|
|
|
|
|
self.trainer_id = 0
|
|
|
|
|
self.pserver_id = None
|
|
|
|
|
self.pserver_endpoints = None
|
|
|
|
|
self.lookup_table_name = None
|
|
|
|
|
self.checkpoint_cfg = checkpoint_config
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
|
|
|
|
@ -282,10 +285,12 @@ class Trainer(object):
|
|
|
|
|
|
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
|
io.load_persistables(
|
|
|
|
|
executor=exe,
|
|
|
|
|
dirname=param_path,
|
|
|
|
|
main_program=self.startup_program)
|
|
|
|
|
_load_persistable_vars(exe, param_path, self.startup_program, False,
|
|
|
|
|
[self.lookup_table_name]
|
|
|
|
|
if self.lookup_table_name else [])
|
|
|
|
|
if self.lookup_table_name and self.pserver_id:
|
|
|
|
|
_load_lookup_table_vars(exe, param_path, self.startup_program,
|
|
|
|
|
self.pserver_id, self.lookup_table_name)
|
|
|
|
|
|
|
|
|
|
def _transpile_nccl2_dist(self):
|
|
|
|
|
# PADDLE_TRAINER_IPS
|
|
|
|
|