|
|
|
@ -61,7 +61,7 @@ class SGD(object):
|
|
|
|
|
self.__gradient_machine__.randParameters()
|
|
|
|
|
parameters.append_gradient_machine(gm)
|
|
|
|
|
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, feeding=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
@ -70,14 +70,13 @@ class SGD(object):
|
|
|
|
|
:param event_handler: Event handler. A method will be invoked when event
|
|
|
|
|
occurred.
|
|
|
|
|
:type event_handler: (BaseEvent) => None
|
|
|
|
|
:param feeding: Feeding is a map of neural network input name and array
|
|
|
|
|
index that reader returns.
|
|
|
|
|
:type feeding: dict
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if event_handler is None:
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
|
if reader_dict is None:
|
|
|
|
|
reader_dict = self.default_reader_dict()
|
|
|
|
|
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
@ -89,9 +88,7 @@ class SGD(object):
|
|
|
|
|
pass_evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
|
assert isinstance(pass_evaluator, api.Evaluator)
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, reader_dict)
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, feeding)
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
|
pass_evaluator.start()
|
|
|
|
@ -125,17 +122,8 @@ class SGD(object):
|
|
|
|
|
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
|
|
|
|
|
self.__gradient_machine__.finish()
|
|
|
|
|
|
|
|
|
|
def default_reader_dict(self):
|
|
|
|
|
reader_dict = dict()
|
|
|
|
|
for i, tp in enumerate(self.__data_types__):
|
|
|
|
|
reader_dict[tp[0]] = i
|
|
|
|
|
return reader_dict
|
|
|
|
|
|
|
|
|
|
def test(self, reader, reader_dict=None):
|
|
|
|
|
if reader_dict is None:
|
|
|
|
|
reader_dict = self.default_reader_dict()
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, reader_dict)
|
|
|
|
|
def test(self, reader, feeding=None):
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, feeding)
|
|
|
|
|
evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
evaluator.start()
|
|
|
|
|