|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
import sys
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -104,7 +105,9 @@ def main():
|
|
|
|
|
parameters = paddle.parameters.create(cost)
|
|
|
|
|
|
|
|
|
|
# define optimize method and trainer
|
|
|
|
|
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
|
|
|
|
|
optimizer = paddle.optimizer.Adam(
|
|
|
|
|
learning_rate=5e-5,
|
|
|
|
|
regularization=paddle.optimizer.L2Regularization(rate=1e-3))
|
|
|
|
|
trainer = paddle.trainer.SGD(cost=cost,
|
|
|
|
|
parameters=parameters,
|
|
|
|
|
update_equation=optimizer)
|
|
|
|
@ -125,8 +128,11 @@ def main():
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 10 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
print "\nPass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
else:
|
|
|
|
|
sys.stdout.write('.')
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
# start to train
|
|
|
|
|
trainer.train(
|
|
|
|
|