commit 4a0cb3db6c
@@ -0,0 +1 @@
./doc/howto/dev/contribute_to_paddle_en.md
@@ -0,0 +1,31 @@
# External dependency to Google protobuf.
http_archive(
    name="protobuf",
    url="http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
    sha256="0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
    strip_prefix="protobuf-3.1.0")

# External dependency to gtest 1.7.0. This method comes from
# https://www.bazel.io/versions/master/docs/tutorial/cpp.html.
new_http_archive(
    name="gtest",
    url="https://github.com/google/googletest/archive/release-1.7.0.zip",
    sha256="b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
    build_file="third_party/gtest.BUILD",
    strip_prefix="googletest-release-1.7.0")

# External dependency to gflags. This method comes from
# https://github.com/gflags/example/blob/master/WORKSPACE.
new_git_repository(
    name="gflags",
    tag="v2.2.0",
    remote="https://github.com/gflags/gflags.git",
    build_file="third_party/gflags.BUILD")

# External dependency to glog. This method comes from
# https://github.com/reyoung/bazel_playground/blob/master/WORKSPACE
new_git_repository(
    name="glog",
    remote="https://github.com/google/glog.git",
    commit="b6a5e0524c28178985f0d228e9eaa43808dbec3c",
    build_file="third_party/glog.BUILD")
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,5 @@
dataprovider.pyc
empty.list
train.log
output
train.list
@@ -0,0 +1,147 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os, sys
import numpy as np
from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import sparse_binary_vector
from paddle.trainer.config_parser import parse_config
"""
Usage: run the following command to show the help message.
    python api_predict.py -h
"""


class QuickStartPrediction():
    def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
        """
        train_conf: trainer configuration.
        dict_file: word dictionary file name.
        model_dir: directory of the model.
        """
        self.train_conf = train_conf
        self.dict_file = dict_file
        self.word_dict = {}
        self.dict_dim = self.load_dict()
        self.model_dir = model_dir
        if model_dir is None:
            self.model_dir = os.path.dirname(train_conf)

        self.label = None
        if label_file is not None:
            self.load_label(label_file)

        conf = parse_config(train_conf, "is_predict=1")
        self.network = swig_paddle.GradientMachine.createFromConfigProto(
            conf.model_config)
        self.network.loadParameters(self.model_dir)
        input_types = [sparse_binary_vector(self.dict_dim)]
        self.converter = DataProviderConverter(input_types)

    def load_dict(self):
        """
        Load the dictionary from self.dict_file.
        """
        for line_count, line in enumerate(open(self.dict_file, 'r')):
            self.word_dict[line.strip().split('\t')[0]] = line_count
        return len(self.word_dict)

    def load_label(self, label_file):
        """
        Load labels.
        """
        self.label = {}
        for v in open(label_file, 'r'):
            self.label[int(v.split('\t')[1])] = v.split('\t')[0]

    def get_index(self, data):
        """
        Transform words into integer indices according to the dictionary.
        """
        words = data.strip().split()
        word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
        return word_slot

    def batch_predict(self, data_batch):
        input = self.converter(data_batch)
        output = self.network.forwardTest(input)
        prob = output[0]["id"].tolist()
        print("predicted labels are:")
        print(prob)


def option_parser():
    usage = "python predict.py -n config -w model_dir -d dictionary -i input_file"
    parser = OptionParser(usage="usage: %s [options]" % usage)
    parser.add_option(
        "-n",
        "--tconf",
        action="store",
        dest="train_conf",
        help="network config")
    parser.add_option(
        "-d",
        "--dict",
        action="store",
        dest="dict_file",
        help="dictionary file")
    parser.add_option(
        "-b",
        "--label",
        action="store",
        dest="label",
        default=None,
        help="label file")
    parser.add_option(
        "-c",
        "--batch_size",
        type="int",
        action="store",
        dest="batch_size",
        default=1,
        help="the batch size for prediction")
    parser.add_option(
        "-w",
        "--model",
        action="store",
        dest="model_path",
        default=None,
        help="model path")
    return parser.parse_args()


def main():
    options, args = option_parser()
    train_conf = options.train_conf
    batch_size = options.batch_size
    dict_file = options.dict_file
    model_path = options.model_path
    label = options.label
    swig_paddle.initPaddle("--use_gpu=0")
    predict = QuickStartPrediction(train_conf, dict_file, model_path, label)

    batch = []
    labels = []
    for line in sys.stdin:
        [label, text] = line.split("\t")
        labels.append(int(label))
        batch.append([predict.get_index(text)])
    print("labels are:")
    print(labels)
    predict.batch_predict(batch)


if __name__ == '__main__':
    main()
@@ -0,0 +1,30 @@
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e

# Note: the default model is pass-00001; you should make sure the model path
# exists or change the model path.
# Only tested on trainer_config.lr.py.
model=output/pass-00001/
config=trainer_config.lr.py
label=data/labels.list
dict=data/dict.txt
batch_size=20
head -n$batch_size data/test.txt | python api_predict.py \
    --tconf=$config \
    --model=$model \
    --label=$label \
    --dict=$dict \
    --batch_size=$batch_size
Some files were not shown because too many files have changed in this diff