|
|
|
@ -134,7 +134,6 @@ def data_reader():
|
|
|
|
|
for i, line in enumerate(fdict):
|
|
|
|
|
dictionary[line.split('\t')[0]] = i
|
|
|
|
|
|
|
|
|
|
print('dict len : %d' % (len(dictionary)))
|
|
|
|
|
for line_count, line in enumerate(fdata):
|
|
|
|
|
label, comment = line.strip().split('\t\t')
|
|
|
|
|
label = int(label)
|
|
|
|
@ -165,7 +164,7 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 1 == 0:
|
|
|
|
|
if event.batch_id % 100 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
|
|
|
|
@ -175,7 +174,8 @@ if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=paddle.reader.batched(
|
|
|
|
|
data_reader, batch_size=128),
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
data_reader, buf_size=4096), batch_size=128),
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
reader_dict={'word': 0,
|
|
|
|
|
'label': 1},
|
|
|
|
|