|
|
|
@ -16,6 +16,10 @@ class BaseEvent(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CompleteTrainOneBatch(BaseEvent):
|
|
|
|
|
"""
|
|
|
|
|
Event On One Batch Training Complete.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, pass_id, batch_id, cost):
|
|
|
|
|
self.pass_id = pass_id
|
|
|
|
|
self.batch_id = batch_id
|
|
|
|
@ -38,6 +42,11 @@ class ITrainer(object):
|
|
|
|
|
|
|
|
|
|
class SGDTrainer(ITrainer):
|
|
|
|
|
def __init__(self, update_equation):
|
|
|
|
|
"""
|
|
|
|
|
Simple SGD Trainer.
|
|
|
|
|
|
|
|
|
|
:param update_equation: Maybe we should give a DSL for update equation?
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(update_equation, paddle.v2.optimizer.Optimizer):
|
|
|
|
|
raise ValueError()
|
|
|
|
|
|
|
|
|
@ -52,6 +61,21 @@ class SGDTrainer(ITrainer):
|
|
|
|
|
event_handler=None,
|
|
|
|
|
batch_size=32,
|
|
|
|
|
data_types=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param train_data_reader:
|
|
|
|
|
:param topology: Network Topology, a protobuf ModelConfig message.
|
|
|
|
|
:param parameters: The parameter pools.
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
|
:param test_data_reader:
|
|
|
|
|
:param event_handler: Event handler. A method will be invoked when event
|
|
|
|
|
occurred.
|
|
|
|
|
:type event_handler: (BaseEvent) => None
|
|
|
|
|
:param batch_size: Not important, will be removed after data refactor.
|
|
|
|
|
:param data_types: Not important, will be removed after data refactor.
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if event_handler is None:
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
@ -66,6 +90,9 @@ class SGDTrainer(ITrainer):
|
|
|
|
|
assert isinstance(updater, api.ParameterUpdater)
|
|
|
|
|
updater.init(gm)
|
|
|
|
|
|
|
|
|
|
gm.start()
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
|
|
|
|
|
data_types_lists = []
|
|
|
|
|
for each in topology.input_layer_names:
|
|
|
|
|
if each not in data_types:
|
|
|
|
@ -74,22 +101,11 @@ class SGDTrainer(ITrainer):
|
|
|
|
|
|
|
|
|
|
converter = DataProviderConverter(input_types=data_types_lists)
|
|
|
|
|
|
|
|
|
|
def input_reorder(func):
|
|
|
|
|
for item in func():
|
|
|
|
|
retv = []
|
|
|
|
|
for __layer_name__ in topology.input_layer_names:
|
|
|
|
|
retv.append(item[__layer_name__])
|
|
|
|
|
yield retv
|
|
|
|
|
|
|
|
|
|
gm.start()
|
|
|
|
|
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(
|
|
|
|
|
__generator_to_batch__(
|
|
|
|
|
input_reorder(train_data_reader),
|
|
|
|
|
batch_size=batch_size)):
|
|
|
|
|
__data_reader_to_batch__(train_data_reader, batch_size,
|
|
|
|
|
topology)):
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(converter(data_batch), out_args, pass_type)
|
|
|
|
|
for each_param in gm.getParameters():
|
|
|
|
@ -108,7 +124,25 @@ class SGDTrainer(ITrainer):
|
|
|
|
|
gm.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __data_reader_to_batch__(reader, batch_size, topology):
|
|
|
|
|
"""
|
|
|
|
|
This function is not important, and will be removed when data refactored.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def input_reorder(func):
|
|
|
|
|
for item in func():
|
|
|
|
|
retv = []
|
|
|
|
|
for __layer_name__ in topology.input_layer_names:
|
|
|
|
|
retv.append(item[__layer_name__])
|
|
|
|
|
yield retv
|
|
|
|
|
|
|
|
|
|
return __generator_to_batch__(input_reorder(reader), batch_size=batch_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __generator_to_batch__(generator, batch_size):
|
|
|
|
|
"""
|
|
|
|
|
This function is not important, and will be removed when data refactored.
|
|
|
|
|
"""
|
|
|
|
|
ret_val = list()
|
|
|
|
|
for each_item in generator:
|
|
|
|
|
ret_val.append(each_item)
|
|
|
|
@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool):
|
|
|
|
|
|
|
|
|
|
def __check_train_args__(train_data_reader, topology, parameters,
|
|
|
|
|
test_data_reader, event_handler, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Check train function's argument types
|
|
|
|
|
"""
|
|
|
|
|
if not callable(train_data_reader) or not isinstance(train_data_reader(),
|
|
|
|
|
collections.Iterator):
|
|
|
|
|
raise ValueError('train_data_reader should be a function, '
|
|
|
|
|