# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *

# id used for words not in the dictionary
UNK_IDX = 0


# initializer is called by the framework during initialization.
# It allows the user to describe the data types and set up the
# necessary data structures for later use.
# `settings` is an object. initializer needs to properly fill settings.input_types.
# initializer can also store other data structures needed by process().
# In this example, the dictionary is stored in settings.
# `dictionary` and `kwargs` are arguments passed from trainer_config.lr.py.
def initializer(settings, dictionary, **kwargs):
    # Put the word dictionary into settings.
    settings.word_dict = dictionary

    # settings.input_types specifies the data types the data provider
    # generates.
    settings.input_types = [
        # The first input is a sparse_binary_vector, which means each
        # dimension of the vector is either 0 or 1. It is the
        # bag-of-words (BOW) representation of the text.
        sparse_binary_vector(len(dictionary)),
        # The second input is an integer. It represents the category id of the
        # sample. 2 means there are two labels in the dataset
        # (1 for positive and 0 for negative).
        integer_value(2)]


# Declaring a data provider. It has an initializer, 'initializer'.
# It will cache the data generated during the first pass in memory, so that
# no on-the-fly data generation is needed during later passes.
# `settings` is the same object used by initializer().
# `file_name` is the name of a file listed in the train_list or test_list file
# given to define_py_data_sources2(). See trainer_config.lr.py.
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file_name):
    # Open the input data file.
    with open(file_name, 'r') as f:
        # Read each line.
        for line in f:
            # Each line contains the label and the text of the comment,
            # separated by \t.
            label, comment = line.strip().split('\t')

            # Split the text into a list of words.
            words = comment.split()

            # Convert the words into a list of ids by looking them up in
            # word_dict.
            word_vector = [settings.word_dict.get(w, UNK_IDX) for w in words]

            # Return the features for the current comment. The first is a list
            # of ids representing a 0-1 binary sparse vector of the text;
            # the second is the integer id of the label.
            yield word_vector, int(label)
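
# Note (not part of the original demo): a minimal sketch of how this data
# provider might be wired up from trainer_config.lr.py, assuming the usual
# define_py_data_sources2() call and a pre-built word dictionary `word_dict`.
# The file names and module name below are illustrative placeholders.
#
#     define_py_data_sources2(
#         train_list='data/train.list',
#         test_list='data/test.list',
#         module='dataprovider',
#         obj='process',
#         args={'dictionary': word_dict})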


def predict_initializer(settings, dictionary, **kwargs):
    settings.word_dict = dictionary
    settings.input_types = [
        sparse_binary_vector(len(dictionary))
    ]


# Declaring a data provider for prediction. The difference from process()
# is that no label is generated.
@provider(init_hook=predict_initializer, should_shuffle=False)
def process_predict(settings, file_name):
    with open(file_name, 'r') as f:
        for line in f:
            # Each line contains only the text of the comment. Split it into
            # words before looking up ids; iterating over the raw string would
            # yield characters instead of words.
            words = line.strip().split()
            word_vector = [settings.word_dict.get(w, UNK_IDX) for w in words]
            yield word_vector
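
# Note (not part of the original demo): illustrative input formats, based on
# the comments above. process() expects one tab-separated sample per line,
# e.g. "1<TAB>this movie is great", while process_predict() expects the raw,
# unlabeled comment text, one comment per line. The example text is made up.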
|