|
|
|
@ -34,9 +34,9 @@ from common import download
|
|
|
|
|
import tarfile
|
|
|
|
|
import scipy.io as scio
|
|
|
|
|
from paddle.v2.image import *
|
|
|
|
|
from paddle.v2.reader import *
|
|
|
|
|
import os
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
from multiprocessing import cpu_count
|
|
|
|
|
__all__ = ['train', 'test', 'valid']
|
|
|
|
|
|
|
|
|
@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
|
|
|
|
|
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
|
|
|
|
|
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
|
|
|
|
|
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
|
|
|
|
|
# In the official README, 'tstid' is the flag for the test data
|
|
|
|
|
# and 'trnid' is the flag for the training data. But the test set is larger than the training set.
|
|
|
|
|
# So we swap the training data and the test data.
|
|
|
|
|
TRAIN_FLAG = 'tstid'
|
|
|
|
|
TEST_FLAG = 'trnid'
|
|
|
|
|
VALID_FLAG = 'valid'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_mapper(sample):
|
|
|
|
@ -53,8 +59,8 @@ def default_mapper(sample):
|
|
|
|
|
map image bytes data to type needed by model input layer
|
|
|
|
|
'''
|
|
|
|
|
img, label = sample
|
|
|
|
|
img = paddle.image.load_image_bytes(img)
|
|
|
|
|
img = paddle.image.simple_transform(img, 256, 224, True)
|
|
|
|
|
img = load_image_bytes(img)
|
|
|
|
|
img = simple_transform(img, 256, 224, True)
|
|
|
|
|
return img.flatten().astype('float32'), label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -63,7 +69,8 @@ def reader_creator(data_file,
|
|
|
|
|
setid_file,
|
|
|
|
|
dataset_name,
|
|
|
|
|
mapper=default_mapper,
|
|
|
|
|
buffered_size=1024):
|
|
|
|
|
buffered_size=1024,
|
|
|
|
|
use_xmap=True):
|
|
|
|
|
'''
|
|
|
|
|
1. read images from tar file and
|
|
|
|
|
merge images into batch files in 102flowers.tgz_batch/
|
|
|
|
@ -105,11 +112,13 @@ def reader_creator(data_file,
|
|
|
|
|
for sample, label in itertools.izip(data, batch['label']):
|
|
|
|
|
yield sample, int(label)
|
|
|
|
|
|
|
|
|
|
return paddle.reader.xmap_readers(mapper, reader,
|
|
|
|
|
cpu_count(), buffered_size)
|
|
|
|
|
if use_xmap:
|
|
|
|
|
return xmap_readers(mapper, reader, cpu_count(), buffered_size)
|
|
|
|
|
else:
|
|
|
|
|
return map_readers(mapper, reader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
|
|
|
|
|
'''
|
|
|
|
|
Create flowers training set reader.
|
|
|
|
|
It returns a reader, each sample in the reader is
|
|
|
|
@ -128,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
return reader_creator(
|
|
|
|
|
download(DATA_URL, 'flowers', DATA_MD5),
|
|
|
|
|
download(LABEL_URL, 'flowers', LABEL_MD5),
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
|
|
|
|
|
buffered_size)
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
|
|
|
|
|
buffered_size, use_xmap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
|
|
|
|
|
'''
|
|
|
|
|
Create flowers test set reader.
|
|
|
|
|
It returns a reader, each sample in the reader is
|
|
|
|
@ -151,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
return reader_creator(
|
|
|
|
|
download(DATA_URL, 'flowers', DATA_MD5),
|
|
|
|
|
download(LABEL_URL, 'flowers', LABEL_MD5),
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
|
|
|
|
|
buffered_size)
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
|
|
|
|
|
buffered_size, use_xmap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def valid(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
|
|
|
|
|
'''
|
|
|
|
|
Create flowers validation set reader.
|
|
|
|
|
It returns a reader, each sample in the reader is
|
|
|
|
@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024):
|
|
|
|
|
return reader_creator(
|
|
|
|
|
download(DATA_URL, 'flowers', DATA_MD5),
|
|
|
|
|
download(LABEL_URL, 'flowers', LABEL_MD5),
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
|
|
|
|
|
buffered_size)
|
|
|
|
|
download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
|
|
|
|
|
buffered_size, use_xmap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch():
|
|
|
|
|