|
|
|
@ -6,6 +6,8 @@ import paddle.v2.dataset.conll05 as conll05
|
|
|
|
|
import paddle.v2.evaluator as evaluator
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger('paddle')
|
|
|
|
|
|
|
|
|
|
word_dict, verb_dict, label_dict = conll05.get_dict()
|
|
|
|
|
word_dict_len = len(word_dict)
|
|
|
|
|
label_dict_len = len(label_dict)
|
|
|
|
@ -120,19 +122,7 @@ def load_parameter(file_name, h, w):
|
|
|
|
|
return np.fromfile(f, dtype=np.float32).reshape(h, w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_a_batch(inferer, test_data, tag_dict):
    """Decode one batch with ``inferer`` and print the predicted tags per sample.

    Args:
        inferer: object exposing ``infer(input=..., field='id')``. The result
            is treated as one flat sequence of label ids that covers the
            tokens of all samples in ``test_data``, in order.
        test_data: list of samples; ``len(sample[0])`` is the number of tokens
            of that sample.
        tag_dict: mapping from label id to the tag that is printed.

    Raises:
        AssertionError: if the number of returned ids differs from the total
            token count of the batch.
    """
    probs = inferer.infer(input=test_data, field='id')
    assert len(probs) == sum(len(x[0]) for x in test_data)

    # The ids of all samples are concatenated, so the offset has to carry over
    # from one sample to the next.  Resetting it inside the loop made every
    # sample read the ids that belong to the first sample.
    start_id = 0
    for test_sample in test_data:
        end_id = start_id + len(test_sample[0])
        pre_lab = [tag_dict[label_id] for label_id in probs[start_id:end_id]]
        # Single-argument call form prints the same text on Python 2 and 3.
        print(pre_lab)
        start_id = end_id


# The name matches pytest's default ``test_*`` pattern although this is an
# ordinary helper that needs arguments; opt out of test collection.
test_a_batch.__test__ = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(is_predict=False):
|
|
|
|
|
def train():
|
|
|
|
|
paddle.init(use_gpu=False, trainer_count=1)
|
|
|
|
|
|
|
|
|
|
# define network topology
|
|
|
|
@ -189,12 +179,12 @@ def main(is_predict=False):
|
|
|
|
|
def event_handler(event):
|
|
|
|
|
if isinstance(event, paddle.event.EndIteration):
|
|
|
|
|
if event.batch_id % 100 == 0:
|
|
|
|
|
print "Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
|
|
|
|
if event.batch_id % 1000 == 0:
|
|
|
|
|
logger.info("Pass %d, Batch %d, Cost %f, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, event.cost, event.metrics))
|
|
|
|
|
if event.batch_id and event.batch_id % 1000 == 0:
|
|
|
|
|
result = trainer.test(reader=reader, feeding=feeding)
|
|
|
|
|
print "\nTest with Pass %d, Batch %d, %s" % (
|
|
|
|
|
event.pass_id, event.batch_id, result.metrics)
|
|
|
|
|
logger.info("\nTest with Pass %d, Batch %d, %s" %
|
|
|
|
|
(event.pass_id, event.batch_id, result.metrics))
|
|
|
|
|
|
|
|
|
|
if isinstance(event, paddle.event.EndPass):
|
|
|
|
|
# save parameters
|
|
|
|
@ -202,44 +192,86 @@ def main(is_predict=False):
|
|
|
|
|
parameters.to_tar(f)
|
|
|
|
|
|
|
|
|
|
result = trainer.test(reader=reader, feeding=feeding)
|
|
|
|
|
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
|
|
|
|
|
|
|
|
|
|
if not is_predict:
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=reader,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
num_passes=10,
|
|
|
|
|
feeding=feeding)
|
|
|
|
|
else:
|
|
|
|
|
labels_reverse = {}
|
|
|
|
|
for (k, v) in label_dict.items():
|
|
|
|
|
labels_reverse[v] = k
|
|
|
|
|
test_creator = paddle.dataset.conll05.test()
|
|
|
|
|
logger.info("\nTest with Pass %d, %s" %
|
|
|
|
|
(event.pass_id, result.metrics))
|
|
|
|
|
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=reader,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
num_passes=10,
|
|
|
|
|
feeding=feeding)
|
|
|
|
|
|
|
|
|
|
predict = paddle.layer.crf_decoding(
|
|
|
|
|
size=label_dict_len,
|
|
|
|
|
input=feature_out,
|
|
|
|
|
param_attr=paddle.attr.Param(name='crfw'))
|
|
|
|
|
|
|
|
|
|
test_pass = 0
|
|
|
|
|
with gzip.open('params_pass_%d.tar.gz' % (test_pass)) as f:
|
|
|
|
|
parameters = paddle.parameters.Parameters.from_tar(f)
|
|
|
|
|
inferer = paddle.inference.Inference(
|
|
|
|
|
output_layer=predict, parameters=parameters)
|
|
|
|
|
def infer_a_batch(inferer, test_data, word_dict, pred_dict, label_dict):
|
|
|
|
|
probs = inferer.infer(input=test_data, field='id')
|
|
|
|
|
assert len(probs) == sum(len(x[0]) for x in test_data)
|
|
|
|
|
|
|
|
|
|
# prepare test data
|
|
|
|
|
test_data = []
|
|
|
|
|
test_batch_size = 50
|
|
|
|
|
for idx, test_sample in enumerate(test_data):
|
|
|
|
|
start_id = 0
|
|
|
|
|
pred_str = "%s\t" % (pred_dict[test_sample[6][0]])
|
|
|
|
|
|
|
|
|
|
for idx, item in enumerate(test_creator()):
|
|
|
|
|
test_data.append(item[0:8])
|
|
|
|
|
for w, tag in zip(test_sample[0],
|
|
|
|
|
probs[start_id:start_id + len(test_sample[0])]):
|
|
|
|
|
pred_str += "%s[%s] " % (word_dict[w], label_dict[tag])
|
|
|
|
|
print(pred_str.strip())
|
|
|
|
|
start_id += len(test_sample[0])
|
|
|
|
|
|
|
|
|
|
if idx and (not idx % test_batch_size):
|
|
|
|
|
test_a_batch(inferer, test_data, labels_reverse)
|
|
|
|
|
test_data = []
|
|
|
|
|
test_a_batch(inferer, test_data, labels_reverse)
|
|
|
|
|
test_data = []
|
|
|
|
|
|
|
|
|
|
def infer():
    """Run the saved model over the CoNLL-05 test reader and print predictions.

    Builds the network from ``db_lstm()`` with a ``crf_decoding`` layer on top,
    loads the parameters stored in ``params_pass_<test_pass>.tar.gz`` and hands
    the test samples to ``infer_a_batch`` in batches of ``test_batch_size``.
    """
    # Inverted lookups (id -> string) for labels, words and predicates.
    # ``items()`` is used so the code runs on both Python 2 and Python 3.
    label_dict_reverse = dict((value, key) for key, value in label_dict.items())
    word_dict_reverse = dict((value, key) for key, value in word_dict.items())
    pred_dict_reverse = dict((value, key) for key, value in verb_dict.items())

    test_creator = paddle.dataset.conll05.test()

    paddle.init(use_gpu=False, trainer_count=1)

    # define network topology
    feature_out = db_lstm()
    predict = paddle.layer.crf_decoding(
        size=label_dict_len,
        input=feature_out,
        param_attr=paddle.attr.Param(name='crfw'))

    # Index of the training pass whose saved parameters are loaded.
    test_pass = 0
    with gzip.open('params_pass_%d.tar.gz' % (test_pass)) as f:
        parameters = paddle.parameters.Parameters.from_tar(f)
    inferer = paddle.inference.Inference(
        output_layer=predict, parameters=parameters)

    def run_batch(batch):
        # Single place that forwards a batch together with the lookup tables.
        infer_a_batch(
            inferer,
            batch,
            word_dict_reverse,
            pred_dict_reverse,
            label_dict_reverse, )

    # prepare test data
    test_data = []
    test_batch_size = 50

    for item in test_creator():
        # Only the first eight fields of each sample are passed on.
        test_data.append(item[0:8])

        # Flush on the batch's own length.  The previous index-based test
        # (``idx and not idx % test_batch_size``) let the first batch grow to
        # test_batch_size + 1 samples.
        if len(test_data) == test_batch_size:
            run_batch(test_data)
            test_data = []

    # Flush the remainder, but never run inference on an empty batch.
    if test_data:
        run_batch(test_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(is_inferring=False):
    """Entry point: call ``infer()`` when ``is_inferring`` is truthy, else ``train()``."""
    if not is_inferring:
        train()
        return
    infer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
main(is_predict=False)
|
|
|
|
|
main(is_inferring=False)
|
|
|
|
|