You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/example/lstm_aclImdb/imdb.py

156 lines
5.3 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
imdb dataset parser.
"""
import os
from itertools import chain
import numpy as np
import gensim
class ImdbParser:
    """
    Parse aclImdb data into features, labels and an embedding weight matrix.

    Pipeline per segment: sentence -> tokenized -> encoded -> padded -> features.

    Args:
        imdb_path: Root directory of the dataset. It must contain
            ``<seg>/<label>`` sub-directories for seg in ('train', 'test')
            and label in ('pos', 'neg'), each holding one text file per sample.
        glove_path: Directory containing ``glove.6B.<embed_size>d.txt``.
        embed_size: Word-vector dimension; also selects the GloVe file name.
        max_len: Length every encoded sentence is truncated or padded to.
    """

    def __init__(self, imdb_path, glove_path, embed_size=300, max_len=500):
        self.__segs = ['train', 'test']
        self.__label_dic = {'pos': 1, 'neg': 0}
        self.__imdb_path = imdb_path
        self.__glove_dim = embed_size
        self.__max_len = max_len
        self.__glove_file = os.path.join(glove_path, 'glove.6B.{}d.txt'.format(self.__glove_dim))

        # properties, all keyed by segment name ('train' / 'test')
        self.__imdb_datas = {}
        self.__features = {}
        self.__labels = {}
        self.__vocab = {}
        self.__word2idx = {}
        self.__weight_np = {}
        self.__wvmodel = None

    def parse(self):
        """
        Load the word vectors and parse every segment into memory.

        Segments are processed in the order of ``self.__segs`` ('train' first);
        encoding of every segment relies on the 'train' vocabulary already
        being available.
        """
        self.__wvmodel = gensim.models.KeyedVectors.load_word2vec_format(self.__glove_file)

        for seg in self.__segs:
            self.__parse_imdb_datas(seg)
            self.__parse_features_and_labels(seg)
            self.__gen_weight_np(seg)

    def __parse_imdb_datas(self, seg):
        """
        Load ``[sentence, label_id]`` pairs for `seg` from the text files.

        Newlines are removed from each file's content. Files are visited in
        sorted name order so the sample order is reproducible (the order
        returned by ``os.listdir`` is arbitrary).
        """
        data_lists = []
        for label_name, label_id in self.__label_dic.items():
            sentence_dir = os.path.join(self.__imdb_path, seg, label_name)
            for file_name in sorted(os.listdir(sentence_dir)):
                file_path = os.path.join(sentence_dir, file_name)
                with open(file_path, mode='r', encoding='utf8') as f:
                    sentence = f.read().replace('\n', '')
                data_lists.append([sentence, label_id])
        self.__imdb_datas[seg] = data_lists

    def __parse_features_and_labels(self, seg):
        """
        Split the loaded samples of `seg` into features and labels, then
        tokenize, build the vocabulary, encode and pad the features.
        """
        samples = self.__imdb_datas[seg]
        self.__features[seg] = [sentence for sentence, _ in samples]
        self.__labels[seg] = [label for _, label in samples]

        # update feature to tokenized
        self.__update_features_to_tokenized(seg)
        # parse vocab
        self.__parse_vocab(seg)
        # encode feature
        self.__encode_features(seg)
        # padding feature
        self.__padding_features(seg)

    def __update_features_to_tokenized(self, seg):
        """ split every sentence on single spaces and lower-case each token """
        self.__features[seg] = [
            [word.lower() for word in sentence.split(" ")]
            for sentence in self.__features[seg]
        ]

    def __parse_vocab(self, seg):
        """
        Build the vocabulary and the word -> index mapping of `seg`.

        Index 0 is reserved for '<unk>'; real words get indices starting at 1.
        """
        vocab = set(chain.from_iterable(self.__features[seg]))
        self.__vocab[seg] = vocab

        # word_to_idx: {'hello': 1, 'world':111, ... '<unk>': 0}
        # Enumerate the sorted vocabulary: iterating the set directly would
        # yield a different mapping on every run (string hash randomization).
        word_to_idx = {word: i + 1 for i, word in enumerate(sorted(vocab))}
        word_to_idx['<unk>'] = 0
        self.__word2idx[seg] = word_to_idx

    def __encode_features(self, seg):
        """
        encode word to index

        Every segment is encoded with the 'train' vocabulary; words missing
        from it map to 0 ('<unk>').
        """
        word_to_idx = self.__word2idx['train']
        self.__features[seg] = [
            [word_to_idx.get(word, 0) for word in tokenized_sentence]
            for tokenized_sentence in self.__features[seg]
        ]

    def __padding_features(self, seg, maxlen=None, pad=0):
        """
        pad all features to the same length

        Sentences longer than `maxlen` are truncated, shorter ones are filled
        with `pad`. `maxlen` defaults to the ``max_len`` given to the constructor.
        """
        if maxlen is None:
            maxlen = self.__max_len
        # Slicing copies, so the stored sentence lists are never mutated in place;
        # a negative repeat count yields an empty list for already-long sentences.
        self.__features[seg] = [
            feature[:maxlen] + [pad] * (maxlen - len(feature))
            for feature in self.__features[seg]
        ]

    def __gen_weight_np(self, seg):
        """
        generate weight by gensim

        Row ``idx`` holds the vector of the word mapped to ``idx`` in the
        vocabulary of `seg`; words unknown to the vector model keep a zero row.

        NOTE(review): the matrix is indexed by the vocabulary of `seg` itself,
        whereas features of every segment are encoded with the 'train'
        vocabulary, so only the 'train' matrix matches the encoded indices.
        Presumably callers only use the 'train' weight -- verify.
        """
        word_to_idx = self.__word2idx[seg]
        weight_np = np.zeros((len(word_to_idx), self.__glove_dim), dtype=np.float32)
        for word, idx in word_to_idx.items():
            if word not in self.__wvmodel:
                continue
            weight_np[idx, :] = self.__wvmodel.get_vector(word)
        self.__weight_np[seg] = weight_np

    def get_datas(self, seg):
        """
        return features, labels, and weight

        Args:
            seg: Segment name, 'train' or 'test'.

        Returns:
            Tuple ``(features, labels, weight)``: int32 array of padded word
            indices, int32 array of labels, and the float32 weight matrix
            generated for `seg`.

        Raises:
            KeyError: If `seg` has not been parsed.
        """
        features = np.array(self.__features[seg]).astype(np.int32)
        labels = np.array(self.__labels[seg]).astype(np.int32)
        weight = np.array(self.__weight_np[seg])
        return features, labels, weight