Merge pull request #1438 from reyoung/feature/mnist_reader
MNIST dataset reader implementationavx_docs
commit
111e7710ad
@ -0,0 +1,8 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
__all__ = ['DATA_HOME']
|
||||||
|
|
||||||
|
DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set')
|
||||||
|
|
||||||
|
if not os.path.exists(DATA_HOME):
|
||||||
|
os.makedirs(DATA_HOME)
|
@ -0,0 +1,39 @@
|
|||||||
|
import sklearn.datasets.mldata
|
||||||
|
import sklearn.model_selection
|
||||||
|
import numpy
|
||||||
|
from config import DATA_HOME
|
||||||
|
|
||||||
|
__all__ = ['train_creator', 'test_creator']
|
||||||
|
|
||||||
|
|
||||||
|
def __mnist_reader_creator__(data, target):
|
||||||
|
def reader():
|
||||||
|
n_samples = data.shape[0]
|
||||||
|
for i in xrange(n_samples):
|
||||||
|
yield (data[i] / 255.0).astype(numpy.float32), int(target[i])
|
||||||
|
|
||||||
|
return reader
|
||||||
|
|
||||||
|
|
||||||
|
TEST_SIZE = 10000
|
||||||
|
|
||||||
|
data = sklearn.datasets.mldata.fetch_mldata(
|
||||||
|
"MNIST original", data_home=DATA_HOME)
|
||||||
|
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
|
||||||
|
data.data, data.target, test_size=TEST_SIZE, random_state=0)
|
||||||
|
|
||||||
|
|
||||||
|
def train_creator():
|
||||||
|
return __mnist_reader_creator__(X_train, y_train)
|
||||||
|
|
||||||
|
|
||||||
|
def test_creator():
|
||||||
|
return __mnist_reader_creator__(X_test, y_test)
|
||||||
|
|
||||||
|
|
||||||
|
def unittest():
|
||||||
|
assert len(list(test_creator()())) == TEST_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest()
|
Loading…
Reference in new issue