checkpoint api optimized

guochaorong-patch-1
tangwei12 7 years ago
parent 436bb4500b
commit 95545f7676

File diff suppressed because it is too large Load Diff

@ -277,31 +277,14 @@ class Trainer(object):
exe.run(self.startup_program)
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard():
exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial,
self.startup_program)
if not self.checkpoint_cfg.pserver_id:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir,
self.startup_program,
self.checkpoint_cfg.pserver_id,
self.checkpoint_cfg.lookup_table_name)
self._load_checkpoint()
if param_path and os.path.isdir(param_path):
# load params from param_path into scope
io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
io.load_persistables(
executor=exe,
dirname=param_path,
main_program=self.startup_program)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
@ -580,6 +563,42 @@ class Trainer(object):
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def _load_checkpoint(self):
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program)
if not self.checkpoint_cfg.pserver_id:
load_trainer_args = self._get_checkpoint_load_args()
trainer_args = io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.trainer_id,
is_trainer=True,
load_trainer_args=load_trainer_args)
if len(trainer_args) != 2:
raise ValueError(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self.checkpoint_cfg.epoch_id = int(trainer_args[0])
self.checkpoint_cfg.step_id = int(trainer_args[1])
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
main_program=self.startup_program,
role_id=self.checkpoint_cfg.pserver_id,
is_trainer=False,
load_trainer_args=None,
load_lookup_table=self.checkpoint_cfg.lookup_table_name)
def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program):

Loading…
Cancel
Save