|
|
|
|
@ -42,25 +42,35 @@ class ITrainer(object):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SGD(ITrainer):
    def __init__(self, cost, parameters, update_equation):
        """
        Simple SGD Trainer.

        Builds the network topology from ``cost``, binds ``parameters`` to a
        newly created gradient machine and randomizes the parameter values.

        :param cost: target cost of the network; passed to ``Topology``.
        :param parameters: the parameter pool to train.
        :type parameters: v2_parameters.Parameters
        :param update_equation: The optimizer object.
        :type update_equation: v2_optimizer.Optimizer
        :raises TypeError: if ``parameters`` or ``update_equation`` has the
            wrong type.
        """
        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError('parameters should be parameters')

        if not isinstance(update_equation, v2_optimizer.Optimizer):
            # TypeError (not ValueError): the argument has the wrong type,
            # not a bad value of the right type.
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
        topology = Topology(cost)
        self.__optimizer__ = update_equation
        self.__topology__ = topology
        self.__parameters__ = parameters
        self.__topology_in_proto__ = topology.proto()
        self.__data_types__ = topology.data_type()
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        # Let the parameter pool mirror its values into this machine.
        parameters.append_gradient_machine(gm)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()
def train(self,
|
|
|
|
|
reader,
|
|
|
|
|
cost,
|
|
|
|
|
parameters,
|
|
|
|
|
num_passes=1,
|
|
|
|
|
event_handler=None,
|
|
|
|
|
reader_dict=None):
|
|
|
|
|
def train(self, reader, num_passes=1, event_handler=None, reader_dict=None):
|
|
|
|
|
"""
|
|
|
|
|
Training method. Will train num_passes of input data.
|
|
|
|
|
|
|
|
|
|
@ -76,27 +86,22 @@ class SGD(ITrainer):
|
|
|
|
|
if event_handler is None:
|
|
|
|
|
event_handler = default_event_handler
|
|
|
|
|
|
|
|
|
|
topology = Topology(cost)
|
|
|
|
|
if reader_dict is None:
|
|
|
|
|
reader_dict = self.default_reader_dict()
|
|
|
|
|
|
|
|
|
|
__check_train_args__(**locals())
|
|
|
|
|
|
|
|
|
|
gm = api.GradientMachine.createFromConfigProto(
|
|
|
|
|
topology.proto(), api.CREATE_MODE_NORMAL,
|
|
|
|
|
self.__optimizer__.enable_types())
|
|
|
|
|
assert isinstance(gm, api.GradientMachine)
|
|
|
|
|
parameters.append_gradient_machine(gm)
|
|
|
|
|
gm.randParameters()
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
updater.init(gm)
|
|
|
|
|
updater.init(self.__gradient_machine__)
|
|
|
|
|
|
|
|
|
|
gm.start()
|
|
|
|
|
batch_evaluator = gm.makeEvaluator()
|
|
|
|
|
self.__gradient_machine__.start()
|
|
|
|
|
batch_evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
|
assert isinstance(batch_evaluator, api.Evaluator)
|
|
|
|
|
pass_evaluator = gm.makeEvaluator()
|
|
|
|
|
pass_evaluator = self.__gradient_machine__.makeEvaluator()
|
|
|
|
|
assert isinstance(pass_evaluator, api.Evaluator)
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(topology.data_type(), reader_dict)
|
|
|
|
|
feeder = DataFeeder(self.__data_types__, reader_dict)
|
|
|
|
|
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
|
@ -104,16 +109,18 @@ class SGD(ITrainer):
|
|
|
|
|
updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(reader()):
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
self.__gradient_machine__.forwardBackward(
|
|
|
|
|
feeder(data_batch), out_args, pass_type)
|
|
|
|
|
batch_evaluator.start()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id))
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
gm.eval(pass_evaluator)
|
|
|
|
|
gm.eval(batch_evaluator)
|
|
|
|
|
for each_param in gm.getParameters():
|
|
|
|
|
self.__gradient_machine__.forwardBackward(
|
|
|
|
|
feeder(data_batch), out_args, pass_type)
|
|
|
|
|
self.__gradient_machine__.eval(pass_evaluator)
|
|
|
|
|
self.__gradient_machine__.eval(batch_evaluator)
|
|
|
|
|
for each_param in self.__gradient_machine__.getParameters():
|
|
|
|
|
updater.update(each_param)
|
|
|
|
|
# Get cost. We use numpy to calculate total cost for this batch.
|
|
|
|
|
cost_vec = out_args.getSlotValue(0)
|
|
|
|
|
@ -131,22 +138,37 @@ class SGD(ITrainer):
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
pass_evaluator.finish()
|
|
|
|
|
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
|
|
|
|
|
gm.finish()
|
|
|
|
|
self.__gradient_machine__.finish()
|
|
|
|
|
|
|
|
|
|
def default_reader_dict(self):
    """
    Build the default mapping from each input layer name to its positional
    index inside a data batch (order taken from the topology's data types).
    """
    return dict(
        (data_type[0], idx)
        for idx, data_type in enumerate(self.__data_types__))
|
|
|
|
|
|
|
|
|
|
def test(self, reader, reader_dict=None):
    """
    Run one evaluation pass over ``reader`` and return the result.

    :param reader: callable returning an iterator of data batches.
    :param reader_dict: optional mapping from input layer name to batch
        column index; defaults to :meth:`default_reader_dict`.
    :return: v2_event.TestResult wrapping the finished evaluator.
    """
    if reader_dict is None:
        reader_dict = self.default_reader_dict()

    gm = self.__gradient_machine__
    data_feeder = DataFeeder(self.__data_types__, reader_dict)
    out_args = api.Arguments.createArguments(0)
    test_evaluator = gm.makeEvaluator()
    test_evaluator.start()
    for batch in reader():
        gm.forward(data_feeder(batch), out_args, api.PASS_TEST)
        gm.eval(test_evaluator)

    test_evaluator.finish()
    return v2_event.TestResult(evaluator=test_evaluator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check_train_args__(reader, topology, parameters, event_handler, **kwargs):
|
|
|
|
|
def __check_train_args__(reader, event_handler, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Check train function's argument types
|
|
|
|
|
"""
|
|
|
|
|
if not callable(reader) or not isinstance(reader(), collections.Iterator):
|
|
|
|
|
raise TypeError('train_data_reader should be a function, '
|
|
|
|
|
'which can return a iterator')
|
|
|
|
|
|
|
|
|
|
if not isinstance(topology, Topology):
|
|
|
|
|
raise TypeError('topology should be a model config')
|
|
|
|
|
|
|
|
|
|
if not isinstance(parameters, v2_parameters.Parameters):
|
|
|
|
|
raise TypeError('parameters should be a parameter pool')
|
|
|
|
|
|
|
|
|
|
if not callable(event_handler):
|
|
|
|
|
raise TypeError('event handler should be a function')
|
|
|
|
|
|