|
|
|
@ -79,6 +79,10 @@ class SGD(object):
|
|
|
|
|
self.__gradient_machine__ = gm
|
|
|
|
|
self.__gradient_machine__.randParameters()
|
|
|
|
|
parameters.append_gradient_machine(gm)
|
|
|
|
|
self.__parameter_updater__ = None
|
|
|
|
|
|
|
|
|
|
def use_remote_sparse_updater(self):
|
|
|
|
|
return self.__use_sparse_updater__ and not self.__is_local__
|
|
|
|
|
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, feeding=None):
|
|
|
|
|
"""
|
|
|
|
@ -103,6 +107,7 @@ class SGD(object):
|
|
|
|
|
else:
|
|
|
|
|
parameter_updater = self.__optimizer__.create_remote_updater(
|
|
|
|
|
num_passes, self.__use_sparse_updater__)
|
|
|
|
|
self.__parameter_updater__ = parameter_updater
|
|
|
|
|
parameter_updater.init(self.__gradient_machine__)
|
|
|
|
|
|
|
|
|
|
self.__gradient_machine__.start()
|
|
|
|
@ -122,11 +127,12 @@ class SGD(object):
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id))
|
|
|
|
|
pass_type = parameter_updater.startBatch(len(data_batch))
|
|
|
|
|
if self.__use_sparse_updater__ and not self.__is_local__:
|
|
|
|
|
self.__gradient_machine__.prefetch(feeder(data_batch))
|
|
|
|
|
in_args = feeder(data_batch)
|
|
|
|
|
if self.use_remote_sparse_updater():
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
parameter_updater.getParametersRemote()
|
|
|
|
|
self.__gradient_machine__.forwardBackward(
|
|
|
|
|
feeder(data_batch), out_args, pass_type)
|
|
|
|
|
self.__gradient_machine__.forwardBackward(in_args, out_args,
|
|
|
|
|
pass_type)
|
|
|
|
|
self.__gradient_machine__.eval(pass_evaluator)
|
|
|
|
|
self.__gradient_machine__.eval(batch_evaluator)
|
|
|
|
|
for each_param in self.__gradient_machine__.getNonStaticParameters(
|
|
|
|
@ -157,8 +163,11 @@ class SGD(object):
|
|
|
|
|
num_samples = 0.0
|
|
|
|
|
for data_batch in reader():
|
|
|
|
|
num_samples += len(data_batch)
|
|
|
|
|
self.__gradient_machine__.forward(
|
|
|
|
|
feeder(data_batch), out_args, api.PASS_TEST)
|
|
|
|
|
in_args = feeder(data_batch)
|
|
|
|
|
if self.use_remote_sparse_updater():
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
self.__parameter_updater__.getParametersRemote()
|
|
|
|
|
self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
|
|
|
|
|
total_cost += out_args.sum()
|
|
|
|
|
self.__gradient_machine__.eval(evaluator)
|
|
|
|
|
|
|
|
|
|