|
|
|
@ -188,7 +188,7 @@ class Trainer(object):
|
|
|
|
|
if not self.checkpoint.is_pserver:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
|
|
|
|
|
self.trainer_id, ["epoch_id", "step_id"])
|
|
|
|
|
self.trainer_id, self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint.step_id = int(step_id)
|
|
|
|
|
|
|
|
|
@ -432,22 +432,33 @@ class Trainer(object):
|
|
|
|
|
return
|
|
|
|
|
io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
def _get_checkpoint_load_args(self):
|
|
|
|
|
"""
|
|
|
|
|
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
|
|
|
|
|
"""
|
|
|
|
|
return ["epoch_id", "step_id"]
|
|
|
|
|
|
|
|
|
|
def _get_checkpoint_save_args(self, epoch_id, step_id):
|
|
|
|
|
"""
|
|
|
|
|
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
|
|
|
|
|
"""
|
|
|
|
|
trainer_args = {}
|
|
|
|
|
trainer_args["epoch_id"] = epoch_id
|
|
|
|
|
trainer_args["step_id"] = step_id
|
|
|
|
|
return trainer_args
|
|
|
|
|
|
|
|
|
|
def _save_checkpoint(self, epoch_id, step_id):
|
|
|
|
|
if not self.checkpoint:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
|
|
|
|
|
trainer_args = {}
|
|
|
|
|
trainer_args["epoch_id"] = epoch_id
|
|
|
|
|
trainer_args["step_id"] = step_id
|
|
|
|
|
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
io.save_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=self.checkpoint.checkpoint_dir,
|
|
|
|
|
trainer_id=self.trainer_id,
|
|
|
|
|
is_chief=self.chief,
|
|
|
|
|
trainer_args=trainer_args,
|
|
|
|
|
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
|
|
|
|
|
main_program=self.train_program,
|
|
|
|
|
max_num_checkpoints=self.checkpoint.max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|