# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
IMDB dataset parser.
"""
import os
from itertools import chain

import numpy as np
import gensim


class ImdbParser():
    """
    Parse aclImdb data into features and labels.
    Pipeline: sentence -> tokenized -> encoded -> padded -> features.
    """

    def __init__(self, imdb_path, glove_path, embed_size=300):
        self.__segs = ['train', 'test']
        self.__label_dic = {'pos': 1, 'neg': 0}
        self.__imdb_path = imdb_path
        self.__glove_dim = embed_size
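        # e.g. <glove_path>/glove.6B.300d.txt for the default embed_size of 300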
        self.__glove_file = os.path.join(glove_path, 'glove.6B.' + str(self.__glove_dim) + 'd.txt')

        # properties
        self.__imdb_datas = {}
        self.__features = {}
        self.__labels = {}
        self.__vocab = {}
        self.__word2idx = {}
        self.__weight_np = {}
        self.__wvmodel = None

    def parse(self):
        """
        Parse the IMDB data into memory.
        """
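        # Note: load_word2vec_format expects word2vec text format. A raw GloVe
        # file must first be converted (e.g. with gensim's glove2word2vec
        # script, or by prepending a "<vocab_size> <dim>" header line).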
        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)

        for seg in self.__segs:
            self.__parse_imdb_datas(seg)
            self.__parse_features_and_labels(seg)
            self.__gen_weight_np(seg)

    def __parse_imdb_datas(self, seg):
        """
        Load raw sentences and labels from the text files.
        """
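        # expected layout: <imdb_path>/<seg>/<pos|neg>/*.txt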
        data_lists = []
        for label_name, label_id in self.__label_dic.items():
            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
            for file in os.listdir(sentence_dir):
                with open(os.path.join(sentence_dir, file), mode='r', encoding='utf8') as f:
                    sentence = f.read().replace('\n', '')
                    data_lists.append([sentence, label_id])
        self.__imdb_datas[seg] = data_lists

    def __parse_features_and_labels(self, seg):
        """
        Parse sentences into features and labels, then tokenize,
        encode, and pad the features.
        """
        features = []
        labels = []
        for sentence, label in self.__imdb_datas[seg]:
            features.append(sentence)
            labels.append(label)

        self.__features[seg] = features
        self.__labels[seg] = labels

        # update features to tokenized
        self.__update_features_to_tokenized(seg)
        # parse vocab
        self.__parse_vocab(seg)
        # encode features
        self.__encode_features(seg)
        # pad features
        self.__padding_features(seg)

    def __update_features_to_tokenized(self, seg):
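        # naive whitespace tokenization: punctuation stays attached to words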
        tokenized_features = []
        for sentence in self.__features[seg]:
            tokenized_sentence = [word.lower() for word in sentence.split(" ")]
            tokenized_features.append(tokenized_sentence)
        self.__features[seg] = tokenized_features

    def __parse_vocab(self, seg):
        # vocab
        tokenized_features = self.__features[seg]
        vocab = set(chain(*tokenized_features))
        self.__vocab[seg] = vocab

        # word_to_idx: {'hello': 1, 'world': 111, ..., '<unk>': 0}
        word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
        word_to_idx['<unk>'] = 0
        self.__word2idx[seg] = word_to_idx

    def __encode_features(self, seg):
        """ Encode words to indices. """
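        # always encode with the train vocabulary so train and test share one
        # index space; words unseen at train time map to 0 (<unk>)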
        word_to_idx = self.__word2idx['train']
        encoded_features = []
        for tokenized_sentence in self.__features[seg]:
            encoded_sentence = []
            for word in tokenized_sentence:
                encoded_sentence.append(word_to_idx.get(word, 0))
            encoded_features.append(encoded_sentence)
        self.__features[seg] = encoded_features

    def __padding_features(self, seg, maxlen=500, pad=0):
        """ Pad all features to the same length. """
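        # truncate overlong reviews; right-pad short ones with index 0,
        # which this parser also uses for <unk>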
        padded_features = []
        for feature in self.__features[seg]:
            if len(feature) >= maxlen:
                padded_feature = feature[:maxlen]
            else:
                padded_feature = feature
                while len(padded_feature) < maxlen:
                    padded_feature.append(pad)
            padded_features.append(padded_feature)
        self.__features[seg] = padded_features

    def __gen_weight_np(self, seg):
        """
        Build the embedding weight matrix from the gensim model.
        """
        weight_np = np.zeros((len(self.__word2idx[seg]), self.__glove_dim), dtype=np.float32)
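        # rows for words absent from the GloVe vocabulary keep their zero init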
        for word, idx in self.__word2idx[seg].items():
            if word not in self.__wvmodel:
                continue
            word_vector = self.__wvmodel.get_vector(word)
            weight_np[idx, :] = word_vector

        self.__weight_np[seg] = weight_np

    def get_datas(self, seg):
        """
        Return the features, labels, and embedding weight matrix for a segment.
        """
        features = np.array(self.__features[seg]).astype(np.int32)
        labels = np.array(self.__labels[seg]).astype(np.int32)
        weight = np.array(self.__weight_np[seg])
        return features, labels, weight
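

# Minimal usage sketch (the paths below are illustrative assumptions, not
# part of this module):
#
#     parser = ImdbParser('./aclImdb', './glove', embed_size=300)
#     parser.parse()
#     features, labels, weight = parser.get_datas('train')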