# Paddle/python/paddle/v2/trainer.py
"""
8 years ago
Module Trainer
"""
import collections
from topology import Topology
from . import event as v2_event
from . import optimizer as v2_optimizer
from . import parameters as v2_parameters
__all__ = ['SGD']


def default_event_handler(event):
    """
    Default event handler. It will print some log and save the model.

    TODO(yuyang18): Complete it!
    :param event:
    :return:
    """
    pass


class SGD(object):
    """
    Simple SGD Trainer.
    SGD Trainer combines a data reader, a network topology and an update
    equation together to train/test a neural network.

    :param cost: Target cost that the neural network should be optimized towards.
    :type cost: paddle.v2.config_base.Layer
    :param parameters: The parameters dictionary.
    :type parameters: paddle.v2.parameters.Parameters
    :param update_equation: The optimizer object.
    :type update_equation: paddle.v2.optimizer.Optimizer
    :param extra_layers: Some layers in the neural network graph are not
                         in the path of the cost layer.
    :type extra_layers: paddle.v2.config_base.Layer
    :param is_local: Whether to train locally.
    :type is_local: bool
    :param pserver_spec: pserver location, e.g. localhost:3000.
                         If use_etcd is true, pserver_spec indicates
                         the etcd endpoints, e.g. http://127.0.0.1:2379.
    :type pserver_spec: string
    :param use_etcd: Whether to use an etcd pserver.
    :type use_etcd: bool
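
    Example (a minimal sketch; ``my_cost`` is a hypothetical cost layer
    assumed to be built elsewhere with paddle.v2 layers):

    ..  code-block:: python

        # Create parameters from the topology, pick an optimizer, and
        # combine them into a trainer.
        parameters = paddle.v2.parameters.create(my_cost)
        optimizer = paddle.v2.optimizer.Momentum(momentum=0.9)
        trainer = paddle.v2.trainer.SGD(cost=my_cost,
                                        parameters=parameters,
                                        update_equation=optimizer)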
"""

    def __init__(self,
                 cost,
                 parameters,
                 update_equation,
                 extra_layers=None,
                 is_local=True,
                 pserver_spec=None,
                 use_etcd=True):
        if not isinstance(parameters, v2_parameters.Parameters):
            raise TypeError('parameters should be '
                            'paddle.v2.parameters.Parameters')
        if not isinstance(update_equation, v2_optimizer.Optimizer):
            raise TypeError("update equation parameter must be "
                            "paddle.v2.optimizer.Optimizer")
        import py_paddle.swig_paddle as api
        topology = Topology(cost, extra_layers=extra_layers)
        self.__optimizer__ = update_equation
        self.__topology__ = topology
        self.__parameters__ = parameters
        self.__topology_in_proto__ = topology.proto()
        self.__is_local__ = is_local
        self.__pserver_spec__ = pserver_spec
        self.__use_etcd__ = use_etcd

        self.__use_sparse_updater__ = self.__topology__.use_sparse_updater()
        # In local mode, disable sparse_remote_update.
        if is_local:
            for param in self.__topology_in_proto__.parameters:
                if param.sparse_remote_update:
                    param.sparse_remote_update = False

        self.__gm_create_mode__ = api.CREATE_MODE_NORMAL if not \
            self.__use_sparse_updater__ else api.CREATE_MODE_SGD_SPARSE_CPU_TRAINING
        self.__data_types__ = topology.data_type()
        gm = api.GradientMachine.createFromConfigProto(
            self.__topology_in_proto__, self.__gm_create_mode__,
            self.__optimizer__.enable_types())
        assert isinstance(gm, api.GradientMachine)
        self.__gradient_machine__ = gm
        self.__gradient_machine__.randParameters()
        self.__parameters__.append_gradient_machine(gm)
        self.__parameter_updater__ = None

    def __use_remote_sparse_updater__(self):
        return self.__use_sparse_updater__ and not self.__is_local__

    def __prepare_parameter__(self, in_args):
        """
        Prepare parameters before the forward/backward pass.
        When the remote sparse updater is used, parameters should be fetched
        from the parameter server according to the input arguments.

        :param in_args: input arguments of this batch.
        :return:
        """
        if self.__use_remote_sparse_updater__():
            self.__gradient_machine__.prefetch(in_args)
            self.__parameter_updater__.getParametersRemote()

    def save_parameter_to_tar(self, f):
        """
        Save the trained parameters into a tar file.

        :param f: a writable file-like object to receive the tar archive.
        """
        self.__parameter_updater__.catchUpWith()
        self.__parameter_updater__.apply()
        self.__parameter_updater__.getParametersRemote(True, True)
        self.__parameters__.to_tar(f)
        self.__parameter_updater__.restore()
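
    # Usage sketch (hypothetical; ``trainer`` is a trained SGD instance):
    #
    #     with open('params.tar', 'w') as f:
    #         trainer.save_parameter_to_tar(f)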

    def train(self, reader, num_passes=1, event_handler=None, feeding=None):
        """
        Training method. Will train num_passes passes over the input data.

        :param reader: A reader that reads and yields data items. Usually we use a
                       batched reader to do mini-batch training.
        :type reader: collections.Iterable
        :param num_passes: The total number of training passes.
        :param event_handler: Event handler. A method that will be invoked when an
                              event occurs.
        :type event_handler: (BaseEvent) => None
        :param feeding: Feeding is a map from neural network input names to the
                        array indices that the reader returns.
        :type feeding: dict|list
        :return:
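
        Example (a minimal sketch; ``trainer`` and ``my_train_reader`` are
        hypothetical and assumed to match the network's inputs):

        ..  code-block:: python

            # 'image' and 'label' name the network's data layers; the values
            # are indices into each tuple the reader yields.
            feeding = {'image': 0, 'label': 1}

            def event_handler(event):
                if isinstance(event, paddle.v2.event.EndIteration):
                    print "pass %d, batch %d, cost %f" % (
                        event.pass_id, event.batch_id, event.cost)

            trainer.train(
                reader=paddle.batch(my_train_reader, batch_size=32),
                num_passes=10,
                event_handler=event_handler,
                feeding=feeding)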
"""
        import py_paddle.swig_paddle as api
        from data_feeder import DataFeeder
        if event_handler is None:
            event_handler = default_event_handler
        __check_train_args__(**locals())

        self.__parameter_updater__ = self.__optimizer__.create_updater(
            self.__is_local__, num_passes, self.__use_sparse_updater__,
            self.__pserver_spec__, self.__use_etcd__)
        self.__parameter_updater__.init(self.__gradient_machine__)

        self.__gradient_machine__.start()
        batch_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(batch_evaluator, api.Evaluator)
        pass_evaluator = self.__gradient_machine__.makeEvaluator()
        assert isinstance(pass_evaluator, api.Evaluator)
        out_args = api.Arguments.createArguments(0)
        feeder = DataFeeder(self.__data_types__, feeding)
        for pass_id in xrange(num_passes):
            event_handler(v2_event.BeginPass(pass_id))
            pass_evaluator.start()
            self.__parameter_updater__.startPass()
            for batch_id, data_batch in enumerate(reader()):
                batch_evaluator.start()
                event_handler(
                    v2_event.BeginIteration(
                        pass_id=pass_id, batch_id=batch_id))
                pass_type = self.__parameter_updater__.startBatch(
                    len(data_batch))
                in_args = feeder(data_batch)
                self.__prepare_parameter__(in_args)
                self.__gradient_machine__.forwardBackward(in_args, out_args,
                                                          pass_type)
                self.__gradient_machine__.eval(pass_evaluator)
                self.__gradient_machine__.eval(batch_evaluator)
                for each_param in self.__gradient_machine__.getNonStaticParameters():
                    self.__parameter_updater__.update(each_param)
                cost_sum = out_args.sum()
                cost = cost_sum / len(data_batch)
                event_handler(
                    v2_event.EndIteration(
                        pass_id=pass_id,
                        batch_id=batch_id,
                        cost=cost,
                        evaluator=batch_evaluator))
                self.__parameter_updater__.finishBatch(cost)
                batch_evaluator.finish()
            self.__parameter_updater__.finishPass()
            pass_evaluator.finish()
            event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
        self.__gradient_machine__.finish()

    def test(self, reader, feeding=None):
        """
        Testing method. Will test the input data.

        :param reader: A reader that reads and yields data items.
        :type reader: collections.Iterable
        :param feeding: Feeding is a map from neural network input names to the
                        array indices that the reader returns.
        :type feeding: dict
        :return:
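
        Example (a minimal sketch; ``my_test_reader`` and ``trainer`` are
        hypothetical and assumed to be defined as for :code:`train`):

        ..  code-block:: python

            result = trainer.test(
                reader=paddle.batch(my_test_reader, batch_size=32),
                feeding={'image': 0, 'label': 1})
            print result.cost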
"""
        import py_paddle.swig_paddle as api
        from data_feeder import DataFeeder
        feeder = DataFeeder(self.__data_types__, feeding)
        evaluator = self.__gradient_machine__.makeEvaluator()
        out_args = api.Arguments.createArguments(0)
        evaluator.start()
        total_cost = 0
        num_samples = 0.0
        for data_batch in reader():
            num_samples += len(data_batch)
            in_args = feeder(data_batch)
            self.__prepare_parameter__(in_args)
            self.__gradient_machine__.forward(in_args, out_args, api.PASS_TEST)
            total_cost += out_args.sum()
            self.__gradient_machine__.eval(evaluator)
        evaluator.finish()
        return v2_event.TestResult(
            evaluator=evaluator, cost=total_cost / num_samples)


def __check_train_args__(reader, event_handler, **kwargs):
    """
    Check the types of the train method's arguments.
    """
    if not callable(reader) or not isinstance(reader(), collections.Iterator):
        raise TypeError('train_data_reader should be a function '
                        'which returns an iterator')
    if not callable(event_handler):
        raise TypeError('event handler should be a function')