parent
ad93b8f964
commit
5f6c4af3a5
@ -0,0 +1,30 @@
|
|||||||
|
import numpy
|
||||||
|
|
||||||
|
__all__ = ['read_from_mnist']
|
||||||
|
|
||||||
|
|
||||||
|
def read_from_mnist(filename):
|
||||||
|
imgf = filename + "-images-idx3-ubyte"
|
||||||
|
labelf = filename + "-labels-idx1-ubyte"
|
||||||
|
f = open(imgf, "rb")
|
||||||
|
l = open(labelf, "rb")
|
||||||
|
|
||||||
|
f.read(16)
|
||||||
|
l.read(8)
|
||||||
|
|
||||||
|
# Define number of samples for train/test
|
||||||
|
if "train" in filename:
|
||||||
|
n = 60000
|
||||||
|
else:
|
||||||
|
n = 10000
|
||||||
|
|
||||||
|
images = numpy.fromfile(
|
||||||
|
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
|
||||||
|
images = images / 255.0 * 2.0 - 1.0
|
||||||
|
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
|
||||||
|
|
||||||
|
for i in xrange(n):
|
||||||
|
yield {"pixel": images[i, :], 'label': labels[i]}
|
||||||
|
|
||||||
|
f.close()
|
||||||
|
l.close()
|
Loading…
Reference in new issue