|
|
|
@ -28,19 +28,13 @@ class ITrainer(object):
|
|
|
|
|
The interface of Trainer. The only exposed method is `train`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def train(self,
|
|
|
|
|
train_reader_creator,
|
|
|
|
|
topology,
|
|
|
|
|
parameters,
|
|
|
|
|
test_data_reader=None,
|
|
|
|
|
event_handler=None):
|
|
|
|
|
def train(self, reader, topology, parameters, event_handler=None):
|
|
|
|
|
"""
|
|
|
|
|
train method.
|
|
|
|
|
|
|
|
|
|
:param train_reader_creator:
|
|
|
|
|
:param reader:
|
|
|
|
|
:param topology:
|
|
|
|
|
:param parameters:
|
|
|
|
|
:param test_data_reader:
|
|
|
|
|
:param event_handler:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
@ -62,7 +56,7 @@ class SGD(ITrainer):
|
|
|
|
|
self.__optimizer__ = update_equation
|
|
|
|
|
|
|
|
|
|
def train(self,
|
|
|
|
|
train_reader,
|
|
|
|
|
reader,
|
|
|
|
|
topology,
|
|
|
|
|
parameters,
|
|
|
|
|
num_passes=1,
|
|
|
|
@ -72,7 +66,7 @@ class SGD(ITrainer):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param train_reader:
|
|
|
|
|
:param reader:
|
|
|
|
|
:param topology: Network Topology, use one or more Layers to represent it.
|
|
|
|
|
:param parameters: The parameter pools.
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
@ -104,7 +98,7 @@ class SGD(ITrainer):
|
|
|
|
|
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(train_reader()):
|
|
|
|
|
for batch_id, data_batch in enumerate(reader()):
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
for each_param in gm.getParameters():
|
|
|
|
@ -122,13 +116,11 @@ class SGD(ITrainer):
|
|
|
|
|
gm.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check_train_args__(train_reader, topology, parameters, event_handler,
|
|
|
|
|
**kwargs):
|
|
|
|
|
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Check train function's argument types
|
|
|
|
|
"""
|
|
|
|
|
if not callable(train_reader) or not isinstance(train_reader(),
|
|
|
|
|
collections.Iterator):
|
|
|
|
|
if not callable(reader) or not isinstance(reader(), collections.Iterator):
|
|
|
|
|
raise TypeError('train_data_reader should be a function, '
|
|
|
|
|
'which can return a iterator')
|
|
|
|
|
|
|
|
|
|