|
|
|
@ -87,20 +87,34 @@ class SGD(ITrainer):
|
|
|
|
|
topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types())
|
|
|
|
|
assert isinstance(gm, api.GradientMachine)
|
|
|
|
|
parameters.append_gradient_machine(gm)
|
|
|
|
|
|
|
|
|
|
gm.randParameters()
|
|
|
|
|
updater = self.__optimizer__.create_local_updater()
|
|
|
|
|
updater.init(gm)
|
|
|
|
|
|
|
|
|
|
gm.start()
|
|
|
|
|
batch_evaluator = gm.makeEvaluator()
|
|
|
|
|
assert isinstance(batch_evaluator, api.Evaluator)
|
|
|
|
|
pass_evaluator = gm.makeEvaluator()
|
|
|
|
|
assert isinstance(pass_evaluator, api.Evaluator)
|
|
|
|
|
out_args = api.Arguments.createArguments(0)
|
|
|
|
|
|
|
|
|
|
feeder = DataFeeder(data_types, reader_dict)
|
|
|
|
|
|
|
|
|
|
for pass_id in xrange(num_passes):
|
|
|
|
|
event_handler(v2_event.BeginPass(pass_id))
|
|
|
|
|
pass_evaluator.start()
|
|
|
|
|
updater.startPass()
|
|
|
|
|
for batch_id, data_batch in enumerate(reader()):
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
batch_evaluator.start()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.BeginIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id))
|
|
|
|
|
pass_type = updater.startBatch(len(data_batch))
|
|
|
|
|
gm.forwardBackward(feeder(data_batch), out_args, pass_type)
|
|
|
|
|
gm.eval(pass_evaluator)
|
|
|
|
|
gm.eval(batch_evaluator)
|
|
|
|
|
for each_param in gm.getParameters():
|
|
|
|
|
updater.update(each_param)
|
|
|
|
|
# Get cost. We use numpy to calculate total cost for this batch.
|
|
|
|
@ -108,11 +122,17 @@ class SGD(ITrainer):
|
|
|
|
|
cost_vec = cost_vec.copyToNumpyMat()
|
|
|
|
|
cost = cost_vec.sum() / len(data_batch)
|
|
|
|
|
updater.finishBatch(cost)
|
|
|
|
|
batch_evaluator.finish()
|
|
|
|
|
event_handler(
|
|
|
|
|
v2_event.EndIteration(
|
|
|
|
|
pass_id=pass_id, batch_id=batch_id, cost=cost))
|
|
|
|
|
pass_id=pass_id,
|
|
|
|
|
batch_id=batch_id,
|
|
|
|
|
cost=cost,
|
|
|
|
|
evaluator=batch_evaluator))
|
|
|
|
|
|
|
|
|
|
updater.finishPass()
|
|
|
|
|
pass_evaluator.finish()
|
|
|
|
|
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
|
|
|
|
|
gm.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|