@@ -2,6 +2,8 @@
 Module Trainer
 """
 import collections
+import gzip
+import os
 
 import py_paddle.swig_paddle as api
 
@@ -42,7 +44,12 @@ class SGD(object):
     :type extra_layers: paddle.v2.config_base.Layer
     """
 
-    def __init__(self, cost, parameters, update_equation, extra_layers=None):
+    def __init__(self,
+                 cost,
+                 parameters,
+                 update_equation,
+                 extra_layers=None,
+                 is_local=True):
 
         if not isinstance(parameters, v2_parameters.Parameters):
             raise TypeError('parameters should be parameters')
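For orientation, a minimal sketch of calling the widened constructor. The toy 13-feature network, optimizer settings, and layer names below are illustrative assumptions, not part of this patch; only the is_local flag is new:

    import paddle.v2 as paddle

    paddle.init(use_gpu=False, trainer_count=1)

    # A toy regression network, just to have something to train.
    x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
    y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
    y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
    cost = paddle.layer.mse_cost(input=y_predict, label=y)

    parameters = paddle.parameters.create(cost)
    optimizer = paddle.optimizer.Momentum(momentum=0.9, learning_rate=1e-3)

    # is_local=False requests remote (parameter-server) updates; the default,
    # is_local=True, keeps the previous local-update behaviour.
    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=optimizer,
        is_local=False)
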
@@ -55,20 +62,48 @@ class SGD(object):
         self.__topology__ = topology
         self.__parameters__ = parameters
         self.__topology_in_proto__ = topology.proto()
+        self.__is_local__ = is_local
 
-        # In local mode, disable sparse_remote_update.
-        for param in self.__topology_in_proto__.parameters:
-            if param.sparse_remote_update:
-                param.sparse_remote_update = False
+        self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
+        # In local mode, disable sparse_remote_update.
+        if is_local:
+            for param in self.__topology_in_proto__.parameters:
+                if param.sparse_remote_update:
+                    param.sparse_remote_update = False
 
+        self.__gm_create_mode__ = api.CREATE_MODE_NORMAL if not \
+            self.__use_sparse_updater__ else api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
         self.__data_types__ = topology.data_type()
         gm = api.GradientMachine.createFromConfigProto(
-            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
+            self.__topology_in_proto__, self.__gm_create_mode__,
             self.__optimizer__.enable_types())
         assert isinstance(gm, api.GradientMachine)
         self.__gradient_machine__ = gm
         self.__gradient_machine__.randParameters()
-        parameters.append_gradient_machine(gm)
+        self.__parameters__.append_gradient_machine(gm)
+        self.__parameter_updater__ = None
+
+    def __use_remote_sparse_updater__(self):
+        return self.__use_sparse_updater__ and not self.__is_local__
+
+    def __prepare_parameter__(self, in_args):
+        """
+        Prepare parameters before forward/backward. When using the remote
+        sparse updater, parameters are fetched from the parameter
+        server according to the input arguments.
+        :param in_args: input arguments of this batch.
+        :return:
+        """
+        if self.__use_remote_sparse_updater__():
+            self.__gradient_machine__.prefetch(in_args)
+            self.__parameter_updater__.getParametersRemote()
+
+    def save_parameter_to_tar(self, f):
+        self.__parameter_updater__.catchUpWith()
+        self.__parameter_updater__.apply()
+        self.__parameter_updater__.getParametersRemote(True, True)
+        self.__parameters__.to_tar(f)
+        self.__parameter_updater__.restore()
 
     def train(self, reader, num_passes=1, event_handler=None, feeding=None):
         """
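A plausible use of the new save_parameter_to_tar (and of the gzip import added above) is snapshotting inside an event handler: the method catches the updater up and, when training remotely, pulls a consistent copy from the parameter servers before writing. The handler below is an illustrative sketch that assumes the trainer object from the earlier example:

    import gzip

    import paddle.v2 as paddle

    def event_handler(event):
        # Snapshot parameters at the end of every pass.
        if isinstance(event, paddle.event.EndPass):
            with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
                trainer.save_parameter_to_tar(f)
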
@@ -90,8 +125,9 @@ class SGD(object):
             event_handler = default_event_handler
         __check_train_args__(**locals())
 
-        updater = self.__optimizer__.create_local_updater()
-        updater.init(self.__gradient_machine__)
+        self.__parameter_updater__ = self.__optimizer__.create_updater(
+            self.__is_local__, num_passes, self.__use_sparse_updater__)
+        self.__parameter_updater__.init(self.__gradient_machine__)
 
         self.__gradient_machine__.start()
         batch_evaluator = self.__gradient_machine__.makeEvaluator()
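Since the updater is now built inside train() from is_local, num_passes, and the sparse-updater flag, a caller only supplies a reader and, optionally, the event handler above. A toy reader matching the 13-feature sketch (hypothetical data):

    import random

    def train_reader():
        # Each call yields (x, y) samples; paddle.batch groups them into
        # the data_batch lists consumed by the loop in the next hunk.
        for _ in xrange(1000):
            x = [random.random() for _ in range(13)]
            yield x, [sum(x)]

    trainer.train(
        reader=paddle.batch(train_reader, batch_size=128),
        num_passes=10,
        event_handler=event_handler,
        feeding={'x': 0, 'y': 1})
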
@@ -103,23 +139,26 @@ class SGD(object):
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))
             pass_evaluator.start()
-            updater.startPass()
+            self.__parameter_updater__.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
-                pass_type = updater.startBatch(len(data_batch))
-                self.__gradient_machine__.forwardBackward(
-                    feeder(data_batch), out_args, pass_type)
+                pass_type = self.__parameter_updater__.startBatch(
+                    len(data_batch))
+                in_args = feeder(data_batch)
+                self.__prepare_parameter__(in_args)
+                self.__gradient_machine__.forwardBackward(in_args, out_args,
+                                                          pass_type)
                 self.__gradient_machine__.eval(pass_evaluator)
                 self.__gradient_machine__.eval(batch_evaluator)
                 for each_param in self.__gradient_machine__.getNonStaticParameters(
                 ):
-                    updater.update(each_param)
+                    self.__parameter_updater__.update(each_param)
                 cost_sum = out_args.sum()
                 cost = cost_sum / len(data_batch)
-                updater.finishBatch(cost)
+                self.__parameter_updater__.finishBatch(cost)
                 batch_evaluator.finish()
                 event_handler(
                     v2_event.EndIteration(
@@ -128,7 +167,7 @@ class SGD(object):
                         cost=cost,
                         evaluator=batch_evaluator))
 
-            updater.finishPass()
+            self.__parameter_updater__.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
         self.__gradient_machine__.finish()
@@ -152,8 +191,9 @@ class SGD(object):
         num_samples = 0.0
         for data_batch in reader():
             num_samples += len(data_batch)
-            self.__gradient_machine__.forward(
-                feeder(data_batch), out_args, api.PASS_TEST)
+            in_args = feeder(data_batch)
+            self.__prepare_parameter__(in_args)
+            self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
             total_cost += out_args.sum()
             self.__gradient_machine__.eval(evaluator)
 
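Because test() now routes the fed batch through __prepare_parameter__ as well, evaluation works unchanged under remote sparse training. A matching sketch, assuming a test_reader analogous to train_reader above:

    result = trainer.test(
        reader=paddle.batch(test_reader, batch_size=128),
        feeding={'x': 0, 'y': 1})
    print 'Test cost: %f' % result.cost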