You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.3 KiB
91 lines
2.3 KiB
8 years ago
|
import os.path
|
||
|
import io
|
||
|
import numpy as np
|
||
|
import tensorflow as tf
|
||
|
|
||
|
# tflearn
|
||
|
import tflearn
|
||
|
from tflearn.data_utils import to_categorical, pad_sequences
|
||
|
from tflearn.datasets import imdb
|
||
|
|
||
|
|
||
|
# Module-wide handle to TensorFlow's flag values; `create_datasets` below
# reads `FLAGS.max_len` from it.
# NOTE(review): no `max_len` flag is defined in the visible part of this
# file -- confirm it is registered elsewhere before `create_datasets` runs.
FLAGS = tf.app.flags.FLAGS
|
||
|
|
||
|
class DataSet(object):
    """Pair an array of examples with its labels and serve mini-batches.

    Batches are handed out in order. When a request would run past the last
    example, the epoch counter is incremented, examples and labels are
    reshuffled with one shared permutation, and serving restarts at index 0
    (the unserved tail of the previous ordering is not returned).
    """

    def __init__(self, data, labels):
        """Store `data` and `labels`; both must agree on their first dimension.

        Raises:
            AssertionError: if `data.shape[0] != labels.shape[0]`.
        """
        n_rows = data.shape[0]
        assert n_rows == labels.shape[0], (
            'data.shape: %s labels.shape: %s' % (data.shape,
                                                 labels.shape))

        self._data = data
        self._labels = labels
        self._num_examples = n_rows

        # Bookkeeping for sequential batch serving.
        self._index_in_epoch = 0
        self._epochs_completed = 0

    @property
    def data(self):
        """Examples in their current (possibly reshuffled) order."""
        return self._data

    @property
    def labels(self):
        """Labels, kept in the same order as `data`."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples, i.e. `data.shape[0]` at construction time."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times `next_batch` has wrapped around the data."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next `batch_size` examples and labels as a tuple.

        Raises:
            AssertionError: if `batch_size` exceeds the number of examples.
        """
        assert batch_size <= self._num_examples

        begin = self._index_in_epoch
        stop = begin + batch_size

        if stop > self._num_examples:
            # Not enough examples left: close the epoch, reshuffle data and
            # labels with the same permutation, and serve from the start.
            self._epochs_completed += 1
            order = np.random.permutation(self._num_examples)
            self._data = self._data[order]
            self._labels = self._labels[order]
            begin, stop = 0, batch_size

        self._index_in_epoch = stop
        return self._data[begin:stop], self._labels[begin:stop]
|
||
|
|
||
|
|
||
|
def create_datasets(file_path, vocab_size=30000, val_fraction=0.0):
    """Load the IMDB training split and wrap it in a `DataSet`.

    Args:
        file_path: Forwarded to `imdb.load_data` as `path`.
        vocab_size: Forwarded to `imdb.load_data` as `n_words`.
        val_fraction: Forwarded to `imdb.load_data` as `valid_portion`.

    Returns:
        A `DataSet` holding the training sequences after `pad_sequences`
        (length `FLAGS.max_len`, fill value 0) and the training labels
        after `to_categorical` with two classes.
    """
    # Only the first split is used. The other two returned splits were
    # previously padded and one-hot encoded and then thrown away; that dead
    # work is dropped here.
    train, _, _ = imdb.load_data(path=file_path, n_words=vocab_size,
                                 valid_portion=val_fraction, sort_by_len=False)
    trainX, trainY = train

    # Sequence padding so every example has the same length.
    # NOTE(review): `FLAGS.max_len` is not defined in the visible part of
    # this file -- confirm the flag is registered before this is called.
    trainX = pad_sequences(trainX, maxlen=FLAGS.max_len, value=0.)

    # Convert labels to binary (two-class) vectors.
    trainY = to_categorical(trainY, nb_classes=2)

    return DataSet(trainX, trainY)
|
||
|
|
||
|
|
||
|
def main():
    """Build the training dataset from `imdb.pkl`; the result is discarded."""
    pickle_path = 'imdb.pkl'
    create_datasets(pickle_path)
|
||
|
|
||
|
|
||
|
# Script entry point: call `main()` only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|