|
|
|
@ -4,7 +4,8 @@ import paddle.v2 as paddle
|
|
|
|
|
|
|
|
|
|
from seqToseq_net_v2 import seqToseq_net_v2
|
|
|
|
|
|
|
|
|
|
### Data Definiation
|
|
|
|
|
# Data Definiation.
|
|
|
|
|
# TODO:This code should be merged to dataset package.
|
|
|
|
|
data_dir = "./data/pre-wmt14"
|
|
|
|
|
src_lang_dict = os.path.join(data_dir, 'src.dict')
|
|
|
|
|
trg_lang_dict = os.path.join(data_dir, 'trg.dict')
|
|
|
|
@ -68,15 +69,14 @@ def train_reader(file_name):
|
|
|
|
|
def main():
|
|
|
|
|
paddle.init(use_gpu=False, trainer_count=1)
|
|
|
|
|
|
|
|
|
|
# reader = train_reader("data/pre-wmt14/train/train")
|
|
|
|
|
# define network topology
|
|
|
|
|
cost = seqToseq_net_v2(source_dict_dim, target_dict_dim)
|
|
|
|
|
parameters = paddle.parameters.create(cost)
|
|
|
|
|
optimizer = paddle.optimizer.Adam(batch_size=50, learning_rate=5e-4)
|
|
|
|
|
optimizer = paddle.optimizer.Adam(learning_rate=1e-4)
|
|
|
|
|
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 100 == 0:
|
|
|
|
|
if event.batch_id % 10 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
|
|
|
|
@ -93,12 +93,12 @@ def main():
|
|
|
|
|
trn_reader = paddle.reader.batched(
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
train_reader("data/pre-wmt14/train/train"), buf_size=8192),
|
|
|
|
|
batch_size=10)
|
|
|
|
|
batch_size=10000)
|
|
|
|
|
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=trn_reader,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
num_passes=10000,
|
|
|
|
|
num_passes=10,
|
|
|
|
|
reader_dict=reader_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|