|
|
|
@ -139,13 +139,14 @@ class Trainer(object):
|
|
|
|
self.trainer_id = 0
|
|
|
|
self.trainer_id = 0
|
|
|
|
self.chief = True
|
|
|
|
self.chief = True
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
if self.checkpoint and \
|
|
|
|
if self.checkpoint:
|
|
|
|
not isinstance(self.checkpoint, CheckpointConfig):
|
|
|
|
if not isinstance(self.checkpoint, CheckpointConfig):
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.checkpoint._load_serial = io.need_load_checkpoint(
|
|
|
|
else:
|
|
|
|
self.checkpoint.checkpoint_dir)
|
|
|
|
self.checkpoint._load_serial = io.need_load_checkpoint(
|
|
|
|
|
|
|
|
self.checkpoint.checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
|
|
self.scope = core.Scope()
|
|
|
|
self.scope = core.Scope()
|
|
|
|
|
|
|
|
|
|
|
|
@ -175,7 +176,7 @@ class Trainer(object):
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
exe.run(self.startup_program)
|
|
|
|
exe.run(self.startup_program)
|
|
|
|
|
|
|
|
|
|
|
|
if self.checkpoint._load_serial:
|
|
|
|
if self.checkpoint and self.checkpoint._load_serial:
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
self.checkpoint._load_serial,
|
|
|
|
self.checkpoint._load_serial,
|
|
|
|
|