You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/demo/semantic_role_labeling/predict.py

165 lines
5.3 KiB

# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from optparse import OptionParser
from py_paddle import swig_paddle, util, DataProviderWrapperConverter
from paddle.trainer.PyDataProviderWrapper import IndexSlot
from paddle.trainer.config_parser import parse_config
"""
Usage: run following command to show help message.
python predict.py -h
"""
UNK_IDX = 0
class Prediction():
def __init__(self, train_conf, dict_file, model_dir, label_file):
"""
train_conf: trainer configure.
dict_file: word dictionary file name.
model_dir: directory of model.
"""
self.dict = {}
self.labels = {}
self.labels_reverse = {}
self.load_dict_label(dict_file, label_file)
len_dict = len(self.dict)
len_label = len(self.labels)
conf = parse_config(
train_conf,
'dict_len=' + str(len_dict) +
',label_len=' + str(len_label) +
',is_predict=True')
self.network = swig_paddle.GradientMachine.createFromConfigProto(
conf.model_config)
self.network.loadParameters(model_dir)
slots = [IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(len_dict),
IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(2)]
self.converter = util.DataProviderWrapperConverter(True, slots)
def load_dict_label(self, dict_file, label_file):
"""
Load dictionary from self.dict_file.
"""
for line_count, line in enumerate(open(dict_file, 'r')):
self.dict[line.strip()] = line_count
for line_count, line in enumerate(open(label_file, 'r')):
self.labels[line.strip()] = line_count
self.labels_reverse[line_count] = line.strip()
def get_data(self, data_file):
"""
Get input data of paddle format.
"""
with open(data_file, 'r') as fdata:
for line in fdata:
sentence, predicate, ctx_n1, ctx_0, ctx_p1, mark, label = line.strip(
).split('\t')
words = sentence.split()
sen_len = len(words)
word_slot = [self.dict.get(w, UNK_IDX) for w in words]
predicate_slot = [self.dict.get(predicate, UNK_IDX)] * sen_len
ctx_n1_slot = [self.dict.get(ctx_n1, UNK_IDX)] * sen_len
ctx_0_slot = [self.dict.get(ctx_0, UNK_IDX)] * sen_len
ctx_p1_slot = [self.dict.get(ctx_p1, UNK_IDX)] * sen_len
marks = mark.split()
mark_slot = [int(w) for w in marks]
yield word_slot, predicate_slot, ctx_n1_slot, \
ctx_0_slot, ctx_p1_slot, mark_slot
def predict(self, data_file):
"""
data_file: file name of input data.
"""
input = self.converter(self.get_data(data_file))
output = self.network.forwardTest(input)
prob = output[0]["value"]
lab = list(np.argsort(-prob)[:, 0])
with open(data_file, 'r') as fin, open('predict.res', 'w') as fout:
index = 0
for line in fin:
sen = line.split('\t')[0]
len_sen = len(sen.split())
line_labels = lab[index:index + len_sen]
index += len_sen
fout.write(sen + '\t' + ' '.join([self.labels_reverse[
i] for i in line_labels]) + '\n')
def option_parser():
usage = ("python predict.py -c config -w model_dir "
"-d word dictionary -l label_file -i input_file")
parser = OptionParser(usage="usage: %s [options]" % usage)
parser.add_option(
"-c",
"--tconf",
action="store",
dest="train_conf",
help="network config")
parser.add_option(
"-d",
"--dict",
action="store",
dest="dict_file",
help="dictionary file")
parser.add_option(
"-l",
"--label",
action="store",
dest="label_file",
default=None,
help="label file")
parser.add_option(
"-i",
"--data",
action="store",
dest="data_file",
help="data file to predict")
parser.add_option(
"-w",
"--model",
action="store",
dest="model_path",
default=None,
help="model path")
return parser.parse_args()
def main():
options, args = option_parser()
train_conf = options.train_conf
data_file = options.data_file
dict_file = options.dict_file
model_path = options.model_path
label_file = options.label_file
swig_paddle.initPaddle("--use_gpu=0")
predict = Prediction(train_conf, dict_file, model_path, label_file)
predict.predict(data_file)
if __name__ == '__main__':
main()