|
|
|
@ -29,7 +29,8 @@ def default_event_handler(event):
|
|
|
|
|
class SGD(object):
|
|
|
|
|
"""
|
|
|
|
|
Simple SGD Trainer.
|
|
|
|
|
TODO(yuyang18): Complete comments
|
|
|
|
|
SGD Trainer combines data reader, network topolopy and update_equation together
|
|
|
|
|
to train/test a neural network.
|
|
|
|
|
|
|
|
|
|
:param update_equation: The optimizer object.
|
|
|
|
|
:type update_equation: paddle.v2.optimizer.Optimizer
|
|
|
|
@ -65,7 +66,9 @@ class SGD(object):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param reader:
|
|
|
|
|
:param reader: A reader that reads and yeilds data items. Usually we use a
|
|
|
|
|
batched reader to do mini-batch training.
|
|
|
|
|
:type reader: collections.Iterable
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
|
:param event_handler: Event handler. A method will be invoked when event
|
|
|
|
|
occurred.
|
|
|
|
@ -123,6 +126,16 @@ class SGD(object):
|
|
|
|
|
self.__gradient_machine__.finish()
|
|
|
|
|
|
|
|
|
|
def test(self, reader, feeding=None):
|
|
|
|
|
"""
|
|
|
|
|
Testing method. Will test input data.
|
|
|
|
|
|
|
|
|
|
:param reader: A reader that reads and yeilds data items.
|
|
|
|
|
:type reader: collections.Iterable
|
|
|
|
|
:param feeding: Feeding is a map of neural network input name and array
|
|
|
|
|
index that reader returns.
|
|
|
|
|
:type feeding: dict
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, feeding)
|
|
|
|
|
evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|