|
|
|
@ -27,19 +27,13 @@ class ITrainer(object):
|
|
|
|
|
The interface of Trainer. The only exposed method is `train`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def train(self,
|
|
|
|
|
train_data_reader,
|
|
|
|
|
cost,
|
|
|
|
|
parameters,
|
|
|
|
|
test_data_reader=None,
|
|
|
|
|
event_handler=None):
|
|
|
|
|
def train(self, reader, topology, parameters, event_handler=None):
|
|
|
|
|
"""
|
|
|
|
|
train method.
|
|
|
|
|
|
|
|
|
|
:param train_data_reader:
|
|
|
|
|
:param cost:
|
|
|
|
|
:param reader:
|
|
|
|
|
:param topology:
|
|
|
|
|
:param parameters:
|
|
|
|
|
:param test_data_reader:
|
|
|
|
|
:param event_handler:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
@ -61,26 +55,22 @@ class SGD(ITrainer):
|
|
|
|
|
self.__optimizer__ = update_equation
|
|
|
|
|
|
|
|
|
|
def train(self,
|
|
|
|
|
train_data_reader,
|
|
|
|
|
reader,
|
|
|
|
|
cost,
|
|
|
|
|
parameters,
|
|
|
|
|
num_passes=1,
|
|
|
|
|
test_data_reader=None,
|
|
|
|
|
event_handler=None,
|
|
|
|
|
batch_size=32,
|
|
|
|
|
reader_dict=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param train_data_reader:
|
|
|
|
|
:param cost: cost layers, to be optimized.
|
|
|
|
|
:param reader:
|
|
|
|
|
:param topology: Network Topology, use one or more Layers to represent it.
|
|
|
|
|
:param parameters: The parameter pools.
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
|
:param test_data_reader:
|
|
|
|
|
:param event_handler: Event handler. A method will be invoked when event
|
|
|
|
|
occurred.
|
|
|
|
|
:type event_handler: (BaseEvent) => None
|
|
|
|
|
:param batch_size: Not important, will be removed after data refactor.
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if event_handler is None:
|
|
|
|
@ -112,9 +102,9 @@ class SGD(ITrainer):
|
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
|
pass_evaluator.start()
|
|
|
|
|
updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(
|
|
|
|
|
__data_reader_to_batch__(train_data_reader, batch_size,
|
|
|
|
|
topology)):
|
|
|
|
|
for batch_id, data_batch in enumerate(reader()):
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
batch_evaluator.start()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
@ -144,56 +134,19 @@ class SGD(ITrainer):
|
|
|
|
|
gm.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __data_reader_to_batch__(reader, batch_size, topology):
|
|
|
|
|
"""
|
|
|
|
|
This function is not important, and will be removed when data refactored.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def input_reorder(func):
|
|
|
|
|
for item in func():
|
|
|
|
|
retv = []
|
|
|
|
|
for __layer_name__ in topology.proto().input_layer_names:
|
|
|
|
|
retv.append(item[__layer_name__])
|
|
|
|
|
yield retv
|
|
|
|
|
|
|
|
|
|
return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __generator_to_batch__(generator, batch_size):
|
|
|
|
|
"""
|
|
|
|
|
This function is not important, and will be removed when data refactored.
|
|
|
|
|
"""
|
|
|
|
|
ret_val = list()
|
|
|
|
|
for each_item in generator:
|
|
|
|
|
ret_val.append(each_item)
|
|
|
|
|
if len(ret_val) == batch_size:
|
|
|
|
|
yield ret_val
|
|
|
|
|
ret_val = list()
|
|
|
|
|
if len(ret_val) != 0:
|
|
|
|
|
yield ret_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check_train_args__(train_data_reader, topology, parameters,
|
|
|
|
|
test_data_reader, event_handler, **kwargs):
|
|
|
|
|
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Check train function's argument types
|
|
|
|
|
"""
|
|
|
|
|
if not callable(train_data_reader) or not isinstance(train_data_reader(),
|
|
|
|
|
collections.Iterator):
|
|
|
|
|
raise ValueError('train_data_reader should be a function, '
|
|
|
|
|
'which can return a iterator')
|
|
|
|
|
|
|
|
|
|
if test_data_reader is not None:
|
|
|
|
|
if not callable(test_data_reader) or not isinstance(
|
|
|
|
|
test_data_reader(), collections.Iterator):
|
|
|
|
|
raise ValueError('test_data_reader should be a function, which can '
|
|
|
|
|
'return a iterator')
|
|
|
|
|
if not callable(reader) or not isinstance(reader(), collections.Iterator):
|
|
|
|
|
raise TypeError('train_data_reader should be a function, '
|
|
|
|
|
'which can return a iterator')
|
|
|
|
|
|
|
|
|
|
if not isinstance(topology, Topology):
|
|
|
|
|
raise ValueError('topology should be a model config')
|
|
|
|
|
raise TypeError('topology should be a model config')
|
|
|
|
|
|
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|
|
raise ValueError('parameters should be a parameter pool')
|
|
|
|
|
raise TypeError('parameters should be a parameter pool')
|
|
|
|
|
|
|
|
|
|
if not callable(event_handler):
|
|
|
|
|
raise ValueError('event handler should be a function')
|
|
|
|
|
raise TypeError('event handler should be a function')
|
|
|
|
|