|
|
|
@ -42,7 +42,7 @@ class SGD(object):
|
|
|
|
|
:type extra_layers: paddle.v2.config_base.Layer
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, cost, parameters, update_equation, extra_layers=None):
|
|
|
|
|
def __init__(self, cost, parameters, update_equation, extra_layers=None, is_local=True):
|
|
|
|
|
|
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|
|
raise TypeError('parameters should be parameters')
|
|
|
|
@ -55,15 +55,21 @@ class SGD(object):
|
|
|
|
|
self.__topology__ = topology
|
|
|
|
|
self.__parameters__ = parameters
|
|
|
|
|
self.__topology_in_proto__ = topology.proto()
|
|
|
|
|
|
|
|
|
|
# In local mode, disable sparse_remote_update.
|
|
|
|
|
for param in self.__topology_in_proto__.parameters:
|
|
|
|
|
if param.sparse_remote_update:
|
|
|
|
|
param.sparse_remote_update = False
|
|
|
|
|
|
|
|
|
|
self.__is_local__ = is_local
|
|
|
|
|
|
|
|
|
|
self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
|
|
|
|
|
# # In local mode, disable sparse_remote_update.
|
|
|
|
|
if is_local:
|
|
|
|
|
self.__use_sparse_updater__ = False
|
|
|
|
|
for param in self.__topology_in_proto__.parameters:
|
|
|
|
|
if param.sparse_remote_update:
|
|
|
|
|
param.sparse_remote_update = False
|
|
|
|
|
|
|
|
|
|
self.__gm_create_mode__ = api.CREATE_MODE_NORMAL if not \
|
|
|
|
|
self.__use_sparse_updater__ else api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
|
|
|
|
|
self.__data_types__ = topology.data_type()
|
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
|
|
|
|
|
self.__topology_in_proto__, self.__gm_create_mode__,
|
|
|
|
|
self.__optimizer__.enable_types())
|
|
|
|
|
assert isinstance(gm, api.GradientMachine)
|
|
|
|
|
self.__gradient_machine__ = gm
|
|
|
|
@ -88,7 +94,10 @@ class SGD(object):
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
if self.__is_local__:
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
else:
|
|
|
|
|
updater = self.__optimizer__.create_remote_updater(num_passes)
|
|
|
|
|
updater.init(self.__gradient_machine__)
|
|
|
|
|
|
|
|
|
|
self.__gradient_machine__.start()
|
|
|
|
@ -108,6 +117,9 @@ class SGD(object):
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id))
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
if self.__use_sparse_updater__:
|
|
|
|
|
self.__gradient_machine__.prefetch(feeder(data_batch))
|
|
|
|
|
updater.getParametersRemote()
|
|
|
|
|
self.__gradient_machine__.forwardBackward(
|
|
|
|
|
feeder(data_batch), out_args, pass_type)
|
|
|
|
|
self.__gradient_machine__.eval(pass_evaluator)
|
|
|
|
|