|
|
|
|
@ -100,6 +100,7 @@ class Trainer(object):
|
|
|
|
|
param_path=None,
|
|
|
|
|
place=None,
|
|
|
|
|
parallel=False):
|
|
|
|
|
self.__stop = False
|
|
|
|
|
self.parallel = parallel
|
|
|
|
|
# 1. we need to generate a framework.Program by calling
|
|
|
|
|
# program_func. Reference: fluid.program_guard in
|
|
|
|
|
@ -210,6 +211,12 @@ class Trainer(object):
|
|
|
|
|
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def stop(self):
    """
    Request that training stop.

    Sets the private ``__stop`` flag on this object to True and returns
    nothing. The visible step loop tests this flag before starting each
    step and returns as soon as it is set.
    """
    # Keep the direct attribute assignment: inside the class body the
    # double-underscore name is mangled, which setattr() would not do.
    self.__stop = True
|
|
|
|
|
|
|
|
|
|
def train(self, num_epochs, event_handler, reader=None, feed_order=None):
|
|
|
|
|
"""
|
|
|
|
|
Train the model.
|
|
|
|
|
@ -289,6 +296,8 @@ class Trainer(object):
|
|
|
|
|
for epoch_id in range(num_epochs):
|
|
|
|
|
event_handler(BeginEpochEvent(epoch_id))
|
|
|
|
|
for step_id, data in enumerate(reader()):
|
|
|
|
|
if self.__stop:
|
|
|
|
|
return
|
|
|
|
|
begin_event = BeginStepEvent(epoch_id, step_id)
|
|
|
|
|
event_handler(begin_event)
|
|
|
|
|
if begin_event.fetch_metrics:
|
|
|
|
|
@ -327,9 +336,7 @@ class Trainer(object):
|
|
|
|
|
feeder = data_feeder.DataFeeder(
|
|
|
|
|
feed_list=feed_var_list, place=self.place)
|
|
|
|
|
reader = feeder.decorate_reader(reader, multi_devices=True)
|
|
|
|
|
for epoch_id in range(num_epochs):
|
|
|
|
|
self._train_by_any_executor(event_handler, pe, num_epochs,
|
|
|
|
|
reader)
|
|
|
|
|
self._train_by_any_executor(event_handler, pe, num_epochs, reader)
|
|
|
|
|
|
|
|
|
|
def _get_parallel_executor(self):
|
|
|
|
|
return getattr(self, 'parallel_executor', None)
|
|
|
|
|
|