|
|
|
@ -141,11 +141,7 @@ class Trainer(object):
|
|
|
|
|
self.chief = True
|
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
if not isinstance(self.checkpoint, CheckpointConfig):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
assert isinstance(self.checkpoint, CheckpointConfig)
|
|
|
|
|
serial = io.get_latest_checkpoint_serial(
|
|
|
|
|
self.checkpoint.checkpoint_dir)
|
|
|
|
|
self.checkpoint.load_serial = serial if serial >= 0 else None
|
|
|
|
@ -385,8 +381,8 @@ class Trainer(object):
|
|
|
|
|
else:
|
|
|
|
|
metrics = exe.run(feed=data, fetch_list=[])
|
|
|
|
|
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id, metrics))
|
|
|
|
|
self._save_checkpoint(epoch_id, step_id)
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id, metrics))
|
|
|
|
|
event_handler(EndEpochEvent(epoch_id))
|
|
|
|
|
self._clean_checkpoint()
|
|
|
|
|
|
|
|
|
|