|
|
|
@ -73,31 +73,34 @@ def main():
|
|
|
|
|
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim)
|
|
|
|
|
parameters = paddle.parameters.create(cost)
|
|
|
|
|
|
|
|
|
|
# define optimize method and trainer
|
|
|
|
|
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
|
|
|
|
|
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 10 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
|
|
|
|
|
trainer = paddle.trainer.SGD(cost=cost,
|
|
|
|
|
parameters=parameters,
|
|
|
|
|
update_equation=optimizer)
|
|
|
|
|
|
|
|
|
|
# define data reader
|
|
|
|
|
reader_dict = {
|
|
|
|
|
'source_language_word': 0,
|
|
|
|
|
'target_language_word': 1,
|
|
|
|
|
'target_language_next_word': 2
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
trn_reader = paddle.reader.batched(
|
|
|
|
|
wmt14_reader = paddle.reader.batched(
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
|
|
|
|
|
batch_size=5)
|
|
|
|
|
|
|
|
|
|
# define event_handler callback
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 10 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
|
|
|
|
|
# start to train
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=trn_reader,
|
|
|
|
|
reader=wmt14_reader,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
num_passes=10000,
|
|
|
|
|
reader_dict=reader_dict)
|
|
|
|
|