wangkuiyi-patch-1
tangwei12 7 years ago
parent bca4da4225
commit 46f2688f30

@ -356,10 +356,14 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
epochs = [
epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id
]
if self.checkpoint:
epochs = [
epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint.epoch_id
]
else:
epochs = [epoch_id for epoch_id in range(num_epochs)]
for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()):

Loading…
Cancel
Save