|
|
|
@ -49,12 +49,14 @@ class BeginStepEvent(object):
|
|
|
|
|
def __init__(self, epoch_id, step_id):
|
|
|
|
|
self.epoch = epoch_id
|
|
|
|
|
self.step = step_id
|
|
|
|
|
self.fetch_metrics = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EndStepEvent(object):
|
|
|
|
|
def __init__(self, epoch_id, step_id):
|
|
|
|
|
def __init__(self, epoch_id, step_id, metrics):
|
|
|
|
|
self.epoch = epoch_id
|
|
|
|
|
self.step = step_id
|
|
|
|
|
self.metrics = metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_get_place(place):
|
|
|
|
@ -259,12 +261,24 @@ class Trainer(object):
|
|
|
|
|
feeder = data_feeder.DataFeeder(
|
|
|
|
|
feed_list=feed_var_list, place=self.place)
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
reader = feeder.decorate_reader(reader, multi_devices=False)
|
|
|
|
|
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
|
|
|
|
|
|
|
|
|
|
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
|
|
|
|
|
for epoch_id in range(num_epochs):
|
|
|
|
|
event_handler(BeginEpochEvent(epoch_id))
|
|
|
|
|
for step_id, data in enumerate(reader()):
|
|
|
|
|
event_handler(BeginStepEvent(epoch_id, step_id))
|
|
|
|
|
exe.run(feed=feeder.feed(data), fetch_list=[])
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id))
|
|
|
|
|
begin_event = BeginStepEvent(epoch_id, step_id)
|
|
|
|
|
event_handler(begin_event)
|
|
|
|
|
if begin_event.fetch_metrics:
|
|
|
|
|
metrics = exe.run(feed=data,
|
|
|
|
|
fetch_list=[
|
|
|
|
|
var.name
|
|
|
|
|
for var in self.train_func_outputs
|
|
|
|
|
])
|
|
|
|
|
else:
|
|
|
|
|
metrics = exe.run(feed=data, fetch_list=[])
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id, metrics))
|
|
|
|
|
event_handler(EndEpochEvent(epoch_id))
|
|
|
|
|
|
|
|
|
|
def _test_by_executor(self, reader, feed_order, fetch_list):
|
|
|
|
@ -293,17 +307,8 @@ class Trainer(object):
|
|
|
|
|
feed_list=feed_var_list, place=self.place)
|
|
|
|
|
reader = feeder.decorate_reader(reader, multi_devices=True)
|
|
|
|
|
for epoch_id in range(num_epochs):
|
|
|
|
|
event_handler(BeginEpochEvent(epoch_id=epoch_id))
|
|
|
|
|
for step_id, data in enumerate(reader()):
|
|
|
|
|
event_handler(
|
|
|
|
|
BeginStepEvent(
|
|
|
|
|
epoch_id=epoch_id, step_id=step_id))
|
|
|
|
|
pe.run(feed=data, fetch_list=[])
|
|
|
|
|
event_handler(
|
|
|
|
|
EndStepEvent(
|
|
|
|
|
epoch_id=epoch_id, step_id=step_id))
|
|
|
|
|
|
|
|
|
|
event_handler(EndEpochEvent(epoch_id=epoch_id))
|
|
|
|
|
self._train_by_any_executor(event_handler, pe, num_epochs,
|
|
|
|
|
reader)
|
|
|
|
|
|
|
|
|
|
def _get_parallel_executor(self):
|
|
|
|
|
return getattr(self, 'parallel_executor', None)
|
|
|
|
|