|
|
|
@ -65,7 +65,6 @@ class SGD(object):
|
|
|
|
|
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
|
|
|
|
|
# # In local mode, disable sparse_remote_update.
|
|
|
|
|
if is_local:
|
|
|
|
|
self.__use_sparse_updater__ = False
|
|
|
|
|
for param in self.__topology_in_proto__.parameters:
|
|
|
|
|
if param.sparse_remote_update:
|
|
|
|
|
param.sparse_remote_update = False
|
|
|
|
@ -100,11 +99,11 @@ class SGD(object):
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
if self.__is_local__:
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
parameter_updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
else:
|
|
|
|
|
updater = self.__optimizer__.create_remote_updater(
|
|
|
|
|
parameter_updater = self.__optimizer__.create_remote_updater(
|
|
|
|
|
num_passes, self.__use_sparse_updater__)
|
|
|
|
|
updater.init(self.__gradient_machine__)
|
|
|
|
|
parameter_updater.init(self.__gradient_machine__)
|
|
|
|
|
|
|
|
|
|
self.__gradient_machine__.start()
|
|
|
|
|
batch_evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
@ -116,26 +115,26 @@ class SGD(object):
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
|
pass_evaluator.start()
|
|
|
|
|
updater.startPass()
|
|
|
|
|
parameter_updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(reader()):
|
|
|
|
|
batch_evaluator.start()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id))
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
if self.__use_sparse_updater__:
|
|
|
|
|
pass_type = parameter_updater.startBatch(len(data_batch))
|
|
|
|
|
if self.__use_sparse_updater__ and not self.__is_local__:
|
|
|
|
|
self.__gradient_machine__.prefetch(feeder(data_batch))
|
|
|
|
|
updater.getParametersRemote()
|
|
|
|
|
parameter_updater.getParametersRemote()
|
|
|
|
|
self.__gradient_machine__.forwardBackward(
|
|
|
|
|
feeder(data_batch), out_args, pass_type)
|
|
|
|
|
self.__gradient_machine__.eval(pass_evaluator)
|
|
|
|
|
self.__gradient_machine__.eval(batch_evaluator)
|
|
|
|
|
for each_param in self.__gradient_machine__.getNonStaticParameters(
|
|
|
|
|
):
|
|
|
|
|
updater.update(each_param)
|
|
|
|
|
parameter_updater.update(each_param)
|
|
|
|
|
cost_sum = out_args.sum()
|
|
|
|
|
cost = cost_sum / len(data_batch)
|
|
|
|
|
updater.finishBatch(cost)
|
|
|
|
|
parameter_updater.finishBatch(cost)
|
|
|
|
|
batch_evaluator.finish()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.EndIteration(
|
|
|
|
@ -144,7 +143,7 @@ class SGD(object):
|
|
|
|
|
cost=cost,
|
|
|
|
|
evaluator=batch_evaluator))
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
parameter_updater.finishPass()
|
|
|
|
|
pass_evaluator.finish()
|
|
|
|
|
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
|
|
|
|
|
self.__gradient_machine__.finish()
|
|
|
|
|