|
|
|
@ -81,7 +81,8 @@ class CheckpointConfig(object):
|
|
|
|
|
|
|
|
|
|
self.epoch_id = 0
|
|
|
|
|
self.step_id = 0
|
|
|
|
|
self._load_serial = None
|
|
|
|
|
self.load_serial = None
|
|
|
|
|
self.is_pserver = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_get_place(place):
|
|
|
|
@ -145,7 +146,7 @@ class Trainer(object):
|
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.checkpoint._load_serial = io.need_load_checkpoint(
|
|
|
|
|
self.checkpoint.load_serial = io.need_load_checkpoint(
|
|
|
|
|
self.checkpoint.checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
self.scope = core.Scope()
|
|
|
|
@ -176,17 +177,18 @@ class Trainer(object):
|
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
|
exe.run(self.startup_program)
|
|
|
|
|
|
|
|
|
|
if self.checkpoint and self.checkpoint._load_serial:
|
|
|
|
|
if self.checkpoint and self.checkpoint.load_serial:
|
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
|
self.checkpoint._load_serial,
|
|
|
|
|
self.checkpoint.load_serial,
|
|
|
|
|
self.startup_program)
|
|
|
|
|
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint.checkpoint_dir, self.checkpoint._load_serial,
|
|
|
|
|
self.trainer_id, ["epoch_id", "step_id"])
|
|
|
|
|
self.checkpoint.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint.step_id = int(step_id)
|
|
|
|
|
if not self.checkpoint.is_pserver:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
|
|
|
|
|
self.trainer_id, ["epoch_id", "step_id"])
|
|
|
|
|
self.checkpoint.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint.step_id = int(step_id)
|
|
|
|
|
|
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
@ -259,6 +261,9 @@ class Trainer(object):
|
|
|
|
|
t.transpile(
|
|
|
|
|
trainer_id, pservers=pserver_endpoints, trainers=trainers)
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
# is_pserver is initialised on the checkpoint config and read back as
# self.checkpoint.is_pserver when deciding whether to load trainer args,
# so the flag must be set on the checkpoint config, not on the Trainer.
self.checkpoint.is_pserver = True
|
|
|
|
|
|
|
|
|
|
self.train_program = t.get_pserver_program(current_endpoint)
|
|
|
|
|
self.startup_program = t.get_startup_program(current_endpoint,
|
|
|
|
|
self.train_program)
|
|
|
|
@ -362,7 +367,7 @@ class Trainer(object):
|
|
|
|
|
self._clean_checkpoint()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if self.checkpoint and self.checkpoint._load_serial \
|
|
|
|
|
if self.checkpoint and self.checkpoint.load_serial \
|
|
|
|
|
and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|