|
|
|
@ -32,7 +32,7 @@ import itertools
|
|
|
|
|
import numpy
|
|
|
|
|
import paddle.dataset.common
|
|
|
|
|
import tarfile
|
|
|
|
|
from six.moves import zip
|
|
|
|
|
import six
|
|
|
|
|
from six.moves import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
|
|
|
|
@ -46,25 +46,22 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
|
|
|
|
|
|
|
|
|
|
def reader_creator(filename, sub_name, cycle=False):
|
|
|
|
|
def read_batch(batch):
|
|
|
|
|
data = batch['data']
|
|
|
|
|
labels = batch.get('labels', batch.get('fine_labels', None))
|
|
|
|
|
data = batch[six.b('data')]
|
|
|
|
|
labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'), None))
|
|
|
|
|
assert labels is not None
|
|
|
|
|
for sample, label in zip(data, labels):
|
|
|
|
|
for sample, label in six.moves.zip(data, labels):
|
|
|
|
|
yield (sample / 255.0).astype(numpy.float32), int(label)
|
|
|
|
|
|
|
|
|
|
def reader():
|
|
|
|
|
with tarfile.open(filename, mode='r') as f:
|
|
|
|
|
names = (each_item.name for each_item in f
|
|
|
|
|
if sub_name in each_item.name)
|
|
|
|
|
names = [each_item.name for each_item in f if sub_name in each_item.name]
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
for name in names:
|
|
|
|
|
import sys
|
|
|
|
|
print(name)
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
print(f.extractfile(name))
|
|
|
|
|
sys.stdout.flush()
|
|
|
|
|
batch = pickle.load(f.extractfile(name))
|
|
|
|
|
if six.PY2:
|
|
|
|
|
batch = pickle.load(f.extractfile(name))
|
|
|
|
|
else:
|
|
|
|
|
batch = pickle.load(f.extractfile(name), encoding='bytes')
|
|
|
|
|
for item in read_batch(batch):
|
|
|
|
|
yield item
|
|
|
|
|
if not cycle:
|
|
|
|
|