|
|
|
@ -62,7 +62,7 @@ class SGD(ITrainer):
|
|
|
|
|
|
|
|
|
|
def train(self,
|
|
|
|
|
train_data_reader,
|
|
|
|
|
topology,
|
|
|
|
|
cost,
|
|
|
|
|
parameters,
|
|
|
|
|
num_passes=1,
|
|
|
|
|
test_data_reader=None,
|
|
|
|
@ -73,7 +73,7 @@ class SGD(ITrainer):
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
:param train_data_reader:
|
|
|
|
|
:param topology: cost layers, use one or more Layers to represent it.
|
|
|
|
|
:param cost: cost layers, to be optimized.
|
|
|
|
|
:param parameters: The parameter pools.
|
|
|
|
|
:param num_passes: The total train passes.
|
|
|
|
|
:param test_data_reader:
|
|
|
|
@ -86,7 +86,7 @@ class SGD(ITrainer):
|
|
|
|
|
if event_handler is None:
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
|
topology = Topology(topology)
|
|
|
|
|
topology = Topology(cost)
|
|
|
|
|
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|