|
|
|
@ -277,31 +277,14 @@ class Trainer(object):
|
|
|
|
|
exe.run(self.startup_program)
|
|
|
|
|
|
|
|
|
|
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
|
|
|
|
|
with self._prog_and_scope_guard():
|
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
|
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial,
|
|
|
|
|
self.startup_program)
|
|
|
|
|
|
|
|
|
|
if not self.checkpoint_cfg.pserver_id:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial, self.trainer_id,
|
|
|
|
|
self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint_cfg.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint_cfg.step_id = int(step_id)
|
|
|
|
|
else:
|
|
|
|
|
if self.checkpoint_cfg.lookup_table_name:
|
|
|
|
|
io.load_lookup_table_vars(
|
|
|
|
|
exe, self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.startup_program,
|
|
|
|
|
self.checkpoint_cfg.pserver_id,
|
|
|
|
|
self.checkpoint_cfg.lookup_table_name)
|
|
|
|
|
self._load_checkpoint()
|
|
|
|
|
|
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
|
io.load_persist_vars_without_grad(
|
|
|
|
|
exe, dirname=param_path, program=self.startup_program)
|
|
|
|
|
io.load_persistables(
|
|
|
|
|
executor=exe,
|
|
|
|
|
dirname=param_path,
|
|
|
|
|
main_program=self.startup_program)
|
|
|
|
|
|
|
|
|
|
def _transpile_nccl2_dist(self):
|
|
|
|
|
# PADDLE_TRAINER_IPS
|
|
|
|
@ -580,6 +563,42 @@ class Trainer(object):
|
|
|
|
|
main_program=self.train_program,
|
|
|
|
|
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
def _load_checkpoint(self):
|
|
|
|
|
with self._prog_and_scope_guard():
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
io.load_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
main_program=self.startup_program)
|
|
|
|
|
|
|
|
|
|
if not self.checkpoint_cfg.pserver_id:
|
|
|
|
|
load_trainer_args = self._get_checkpoint_load_args()
|
|
|
|
|
trainer_args = io.load_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
main_program=self.startup_program,
|
|
|
|
|
role_id=self.trainer_id,
|
|
|
|
|
is_trainer=True,
|
|
|
|
|
load_trainer_args=load_trainer_args)
|
|
|
|
|
|
|
|
|
|
if len(trainer_args) != 2:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"the return trainer_args length do not equal _get_checkpoint_load_args"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
|
|
|
|
|
self.checkpoint_cfg.step_id = int(trainer_args[1])
|
|
|
|
|
else:
|
|
|
|
|
if self.checkpoint_cfg.lookup_table_name:
|
|
|
|
|
io.load_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
main_program=self.startup_program,
|
|
|
|
|
role_id=self.checkpoint_cfg.pserver_id,
|
|
|
|
|
is_trainer=False,
|
|
|
|
|
load_trainer_args=None,
|
|
|
|
|
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_feed_var_list(program, feed_order):
|
|
|
|
|
if not isinstance(program, framework.Program):
|
|
|
|
|