@ -15,12 +15,12 @@
import os
import numpy as np
from optparse import OptionParser
from py_paddle import swig_paddle , util, DataProviderWrapp erConverter
from paddle . trainer . PyDataProvider Wrapper import IndexSlot
from py_paddle import swig_paddle , DataProviderConverter
from paddle . trainer . PyDataProvider 2 import integer_value_sequence
from paddle . trainer . config_parser import parse_config
"""
Usage : run following command to show help message .
python predict . py - h
python predict . py - h
"""
UNK_IDX = 0
@ -43,16 +43,22 @@ class Prediction():
conf = parse_config (
train_conf ,
' dict_len= ' + str ( len_dict ) +
' dict_len= ' + str ( len_dict ) +
' ,label_len= ' + str ( len_label ) +
' ,is_predict=True ' )
self . network = swig_paddle . GradientMachine . createFromConfigProto (
conf . model_config )
self . network . loadParameters ( model_dir )
slots = [ IndexSlot ( len_dict ) , IndexSlot ( len_dict ) , IndexSlot ( len_dict ) ,
IndexSlot ( len_dict ) , IndexSlot ( len_dict ) , IndexSlot ( 2 ) ]
self . converter = util . DataProviderWrapperConverter ( True , slots )
slots = [
integer_value_sequence ( len_dict ) ,
integer_value_sequence ( len_dict ) ,
integer_value_sequence ( len_dict ) ,
integer_value_sequence ( len_dict ) ,
integer_value_sequence ( len_dict ) ,
integer_value_sequence ( 2 )
]
self . converter = DataProviderConverter ( slots )
def load_dict_label ( self , dict_file , label_file ) :
"""
@ -109,7 +115,7 @@ class Prediction():
def option_parser ( ) :
usage = ( " python predict.py -c config -w model_dir "
usage = ( " python predict.py -c config -w model_dir "
" -d word dictionary -l label_file -i input_file " )
parser = OptionParser ( usage = " usage: %s [options] " % usage )
parser . add_option (