|
|
|
@ -356,10 +356,14 @@ class Trainer(object):
|
|
|
|
|
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
|
|
|
|
|
|
|
|
|
|
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
|
|
|
|
|
epochs = [
|
|
|
|
|
epoch_id for epoch_id in range(num_epochs)
|
|
|
|
|
if epoch_id >= self.checkpoint.epoch_id
|
|
|
|
|
]
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
epochs = [
|
|
|
|
|
epoch_id for epoch_id in range(num_epochs)
|
|
|
|
|
if epoch_id >= self.checkpoint.epoch_id
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
epochs = [epoch_id for epoch_id in range(num_epochs)]
|
|
|
|
|
|
|
|
|
|
for epoch_id in epochs:
|
|
|
|
|
event_handler(BeginEpochEvent(epoch_id))
|
|
|
|
|
for step_id, data in enumerate(reader()):
|
|
|
|
|