Support predicting the samples from sys.stdin

avx_docs
dangqingqing 8 years ago
parent db3798117b
commit aaecfcc47f

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os, sys
import numpy as np import numpy as np
from optparse import OptionParser from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter from py_paddle import swig_paddle, DataProviderConverter
@ -66,35 +66,42 @@ class SentimentPrediction():
for v in open(label_file, 'r'): for v in open(label_file, 'r'):
self.label[int(v.split('\t')[1])] = v.split('\t')[0] self.label[int(v.split('\t')[1])] = v.split('\t')[0]
def get_data(self, data_file): def get_data(self, data):
""" """
Get input data of paddle format. Get input data of paddle format.
""" """
with open(data_file, 'r') as fdata: for line in data:
for line in fdata: words = line.strip().split()
words = line.strip().split() word_slot = [
word_slot = [ self.word_dict[w] for w in words if w in self.word_dict
self.word_dict[w] for w in words if w in self.word_dict ]
] if not word_slot:
if not word_slot: print "all words are not in dictionary: %s", line
print "all words are not in dictionary: %s", line continue
continue yield [word_slot]
yield [word_slot]
def predict(self, batch_size):
def predict(self, data_file):
""" def batch_predict(batch_data):
data_file: file name of input data. input = self.converter(self.get_data(batch_data))
""" output = self.network.forwardTest(input)
input = self.converter(self.get_data(data_file)) prob = output[0]["value"]
output = self.network.forwardTest(input) labs = np.argsort(-prob)
prob = output[0]["value"] for idx, lab in enumerate(labs):
lab = np.argsort(-prob) if self.label is None:
if self.label is None: print("predicting label is %d" % (lab[0]))
print("%s: predicting label is %d" % (data_file, lab[0][0])) else:
else: print("predicting label is %s" %
print("%s: predicting label is %s" % (self.label[lab[0]]))
(data_file, self.label[lab[0][0]]))
batch = []
for line in sys.stdin:
batch.append(line)
if len(batch) == batch_size:
batch_predict(batch)
batch=[]
if len(batch) > 0:
batch_predict(batch)
def option_parser(): def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file " usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
@ -119,11 +126,13 @@ def option_parser():
default=None, default=None,
help="dictionary file") help="dictionary file")
parser.add_option( parser.add_option(
"-i", "-c",
"--data", "--batch_size",
type="int",
action="store", action="store",
dest="data", dest="batch_size",
help="data file to predict") default=1,
help="the batch size for prediction")
parser.add_option( parser.add_option(
"-w", "-w",
"--model", "--model",
@ -137,13 +146,13 @@ def option_parser():
def main(): def main():
options, args = option_parser() options, args = option_parser()
train_conf = options.train_conf train_conf = options.train_conf
data = options.data batch_size = options.batch_size
dict_file = options.dict_file dict_file = options.dict_file
model_path = options.model_path model_path = options.model_path
label = options.label label = options.label
swig_paddle.initPaddle("--use_gpu=0") swig_paddle.initPaddle("--use_gpu=0")
predict = SentimentPrediction(train_conf, dict_file, model_path, label) predict = SentimentPrediction(train_conf, dict_file, model_path, label)
predict.predict(data) predict.predict(batch_size)
if __name__ == '__main__': if __name__ == '__main__':

@ -19,9 +19,9 @@ set -e
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
-n $config\ --tconf=$config\
-w $model \ --model=$model \
-b $label \ --label=$label \
-d ./data/pre-imdb/dict.txt \ --dict=./data/pre-imdb/dict.txt \
-i ./data/aclImdb/test/pos/10007_10.txt --batch_size=1

Loading…
Cancel
Save