|
|
|
@ -1,61 +1,53 @@
|
|
|
|
|
import sklearn.datasets.mldata
|
|
|
|
|
import sklearn.model_selection
|
|
|
|
|
import numpy
|
|
|
|
|
from config import DATA_HOME
|
|
|
|
|
|
|
|
|
|
__all__ = ['MNISTReader', 'train_reader_creator', 'test_reader_creator']
|
|
|
|
|
__all__ = ['MNIST', 'train_creator', 'test_creator']
|
|
|
|
|
|
|
|
|
|
DATA_HOME = None
|
|
|
|
|
|
|
|
|
|
def __mnist_reader_creator__(data, target):
|
|
|
|
|
def reader():
|
|
|
|
|
n_samples = data.shape[0]
|
|
|
|
|
for i in xrange(n_samples):
|
|
|
|
|
yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
|
|
|
|
|
|
|
|
|
|
def __mnist_reader__(data, target):
|
|
|
|
|
n_samples = data.shape[0]
|
|
|
|
|
for i in xrange(n_samples):
|
|
|
|
|
yield data[i].astype(numpy.float32), int(target[i])
|
|
|
|
|
return reader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MNISTReader(object):
|
|
|
|
|
class MNIST(object):
|
|
|
|
|
"""
|
|
|
|
|
mnist dataset reader. The `train_reader` and `test_reader` method returns
|
|
|
|
|
a iterator of each sample. Each sample is combined by 784-dim float and a
|
|
|
|
|
one-dim label
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, random_state):
|
|
|
|
|
def __init__(self, random_state=0, test_size=10000, **options):
|
|
|
|
|
data = sklearn.datasets.mldata.fetch_mldata(
|
|
|
|
|
"MNIST original", data_home=DATA_HOME)
|
|
|
|
|
n_train = 60000
|
|
|
|
|
self.X_train, self.X_test, self.y_train, self.y_test = sklearn.model_selection.train_test_split(
|
|
|
|
|
data.data / 255.0,
|
|
|
|
|
data.target.astype("int"),
|
|
|
|
|
train_size=n_train,
|
|
|
|
|
random_state=random_state)
|
|
|
|
|
data.data,
|
|
|
|
|
data.target,
|
|
|
|
|
test_size=test_size,
|
|
|
|
|
random_state=random_state,
|
|
|
|
|
**options)
|
|
|
|
|
|
|
|
|
|
def train_reader(self):
|
|
|
|
|
return __mnist_reader__(self.X_train, self.y_train)
|
|
|
|
|
def train_creator(self):
|
|
|
|
|
return __mnist_reader_creator__(self.X_train, self.y_train)
|
|
|
|
|
|
|
|
|
|
def test_reader(self):
|
|
|
|
|
return __mnist_reader__(self.X_test, self.y_test)
|
|
|
|
|
def test_creator(self):
|
|
|
|
|
return __mnist_reader_creator__(self.X_test, self.y_test)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__default_instance__ = MNISTReader(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_reader_creator():
|
|
|
|
|
"""
|
|
|
|
|
Default train set reader creator.
|
|
|
|
|
"""
|
|
|
|
|
return __default_instance__.train_reader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reader_creator():
|
|
|
|
|
"""
|
|
|
|
|
Default test set reader creator.
|
|
|
|
|
"""
|
|
|
|
|
return __default_instance__.test_reader
|
|
|
|
|
__default_instance__ = MNIST()
|
|
|
|
|
train_creator = __default_instance__.train_creator
|
|
|
|
|
test_creator = __default_instance__.test_creator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unittest():
|
|
|
|
|
assert len(list(train_reader_creator()())) == 60000
|
|
|
|
|
size = 12045
|
|
|
|
|
mnist = MNIST(test_size=size)
|
|
|
|
|
assert len(list(mnist.test_creator()())) == size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|