|
|
|
@ -78,12 +78,24 @@ class SGD(object):
|
|
|
|
|
assert isinstance(gm, api.GradientMachine)
|
|
|
|
|
self.__gradient_machine__ = gm
|
|
|
|
|
self.__gradient_machine__.randParameters()
|
|
|
|
|
parameters.append_gradient_machine(gm)
|
|
|
|
|
self.__parameters__.append_gradient_machine(gm)
|
|
|
|
|
self.__parameter_updater__ = None
|
|
|
|
|
|
|
|
|
|
def use_remote_sparse_updater(self):
|
|
|
|
|
def __use_remote_sparse_updater__(self):
|
|
|
|
|
return self.__use_sparse_updater__ and not self.__is_local__
|
|
|
|
|
|
|
|
|
|
def __prepare_parameter__(self, in_args):
|
|
|
|
|
"""
|
|
|
|
|
prepare parameter before forward backward.
|
|
|
|
|
1. When use remote sparse updater, parameters should be got
|
|
|
|
|
from ps according to input arguments.
|
|
|
|
|
:param in_args: input arguments of this batch.
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if self.__use_remote_sparse_updater__():
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
self.__parameter_updater__.getParametersRemote()
|
|
|
|
|
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, feeding=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
@ -125,9 +137,7 @@ class SGD(object):
|
|
|
|
|
pass_type = self.__parameter_updater__.startBatch(
|
|
|
|
|
len(data_batch))
|
|
|
|
|
in_args = feeder(data_batch)
|
|
|
|
|
if self.use_remote_sparse_updater():
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
self.__parameter_updater__.getParametersRemote()
|
|
|
|
|
self.__prepare_parameter__(in_args)
|
|
|
|
|
self.__gradient_machine__.forwardBackward(in_args, out_args,
|
|
|
|
|
pass_type)
|
|
|
|
|
self.__gradient_machine__.eval(pass_evaluator)
|
|
|
|
@ -161,9 +171,7 @@ class SGD(object):
|
|
|
|
|
for data_batch in reader():
|
|
|
|
|
num_samples += len(data_batch)
|
|
|
|
|
in_args = feeder(data_batch)
|
|
|
|
|
if self.use_remote_sparse_updater():
|
|
|
|
|
self.__gradient_machine__.prefetch(in_args)
|
|
|
|
|
self.__parameter_updater__.getParametersRemote()
|
|
|
|
|
self.__prepare_parameter__(in_args)
|
|
|
|
|
self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
|
|
|
|
|
total_cost += out_args.sum()
|
|
|
|
|
self.__gradient_machine__.eval(evaluator)
|
|
|
|
|