|
|
|
@ -15,12 +15,12 @@
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
from optparse import OptionParser
|
|
|
|
|
from py_paddle import swig_paddle, util, DataProviderWrapperConverter
|
|
|
|
|
from paddle.trainer.PyDataProviderWrapper import IndexSlot
|
|
|
|
|
from py_paddle import swig_paddle, DataProviderConverter
|
|
|
|
|
from paddle.trainer.PyDataProvider2 import integer_value_sequence
|
|
|
|
|
from paddle.trainer.config_parser import parse_config
|
|
|
|
|
"""
|
|
|
|
|
Usage: run following command to show help message.
|
|
|
|
|
python predict.py -h
|
|
|
|
|
python predict.py -h
|
|
|
|
|
"""
|
|
|
|
|
UNK_IDX = 0
|
|
|
|
|
|
|
|
|
@ -43,16 +43,22 @@ class Prediction():
|
|
|
|
|
|
|
|
|
|
conf = parse_config(
|
|
|
|
|
train_conf,
|
|
|
|
|
'dict_len=' + str(len_dict) +
|
|
|
|
|
'dict_len=' + str(len_dict) +
|
|
|
|
|
',label_len=' + str(len_label) +
|
|
|
|
|
',is_predict=True')
|
|
|
|
|
self.network = swig_paddle.GradientMachine.createFromConfigProto(
|
|
|
|
|
conf.model_config)
|
|
|
|
|
self.network.loadParameters(model_dir)
|
|
|
|
|
|
|
|
|
|
slots = [IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(len_dict),
|
|
|
|
|
IndexSlot(len_dict), IndexSlot(len_dict), IndexSlot(2)]
|
|
|
|
|
self.converter = util.DataProviderWrapperConverter(True, slots)
|
|
|
|
|
slots = [
|
|
|
|
|
integer_value_sequence(len_dict),
|
|
|
|
|
integer_value_sequence(len_dict),
|
|
|
|
|
integer_value_sequence(len_dict),
|
|
|
|
|
integer_value_sequence(len_dict),
|
|
|
|
|
integer_value_sequence(len_dict),
|
|
|
|
|
integer_value_sequence(2)
|
|
|
|
|
]
|
|
|
|
|
self.converter = DataProviderConverter(slots)
|
|
|
|
|
|
|
|
|
|
def load_dict_label(self, dict_file, label_file):
|
|
|
|
|
"""
|
|
|
|
@ -109,7 +115,7 @@ class Prediction():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def option_parser():
|
|
|
|
|
usage = ("python predict.py -c config -w model_dir "
|
|
|
|
|
usage = ("python predict.py -c config -w model_dir "
|
|
|
|
|
"-d word dictionary -l label_file -i input_file")
|
|
|
|
|
parser = OptionParser(usage="usage: %s [options]" % usage)
|
|
|
|
|
parser.add_option(
|
|
|
|
|