|
|
|
@ -8,7 +8,7 @@ from . import event as v2_event
|
|
|
|
|
from . import optimizer as v2_optimizer
|
|
|
|
|
from . import parameters as v2_parameters
|
|
|
|
|
|
|
|
|
|
__all__ = ['ITrainer', 'SGD']
|
|
|
|
|
__all__ = ['SGD']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_event_handler(event):
|
|
|
|
@ -22,26 +22,7 @@ def default_event_handler(event):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ITrainer(object):
|
|
|
|
|
"""
|
|
|
|
|
The interface of Trainer. The only exposed method is `train`.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def train(self, reader, topology, parameters, event_handler=None):
|
|
|
|
|
"""
|
|
|
|
|
train method.
|
|
|
|
|
|
|
|
|
|
:param reader:
|
|
|
|
|
:param topology:
|
|
|
|
|
:param parameters:
|
|
|
|
|
:param event_handler:
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SGD(ITrainer):
|
|
|
|
|
class SGD():
|
|
|
|
|
def __init__(self, cost, parameters, update_equation):
|
|
|
|
|
"""
|
|
|
|
|
Simple SGD Trainer.
|
|
|
|
|