You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							231 lines
						
					
					
						
							7.8 KiB
						
					
					
				
			
		
		
	
	
							231 lines
						
					
					
						
							7.8 KiB
						
					
					
				| # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| This module will download dataset from
 | |
| http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
 | |
| and parse train/test set into paddle reader creators.
 | |
| 
 | |
| This set contains images of flowers belonging to 102 different categories.
 | |
| The images were acquired by searching the web and taking pictures. There are a
 | |
| minimum of 40 images for each category.
 | |
| 
 | |
| The database was used in:
 | |
| 
 | |
| Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
 | |
| number of classes. Proceedings of the Indian Conference on Computer Vision,
 | |
| Graphics and Image Processing (2008)
 | |
| http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
 | |
| 
 | |
| """
 | |
| 
 | |
| from __future__ import print_function
 | |
| 
 | |
| import itertools
 | |
| import functools
 | |
| from .common import download
 | |
| import tarfile
 | |
| import scipy.io as scio
 | |
| from paddle.dataset.image import *
 | |
| from paddle.reader import map_readers, xmap_readers
 | |
| from paddle import compat as cpt
 | |
| import os
 | |
| import numpy as np
 | |
| from multiprocessing import cpu_count
 | |
| import six
 | |
| from six.moves import cPickle as pickle
 | |
| __all__ = ['train', 'test', 'valid']
 | |
| 
 | |
| DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
 | |
| LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
 | |
| SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
 | |
| DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
 | |
| LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
 | |
| SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
 | |
| # In official 'readme', tstid is the flag of test data
 | |
| # and trnid is the flag of train data. But test data is more than train data.
 | |
| # So we exchange the train data and test data.
 | |
| TRAIN_FLAG = 'tstid'
 | |
| TEST_FLAG = 'trnid'
 | |
| VALID_FLAG = 'valid'
 | |
| 
 | |
| 
 | |
| def default_mapper(is_train, sample):
 | |
|     '''
 | |
|     map image bytes data to type needed by model input layer
 | |
|     '''
 | |
|     img, label = sample
 | |
|     img = load_image_bytes(img)
 | |
|     img = simple_transform(
 | |
|         img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
 | |
|     return img.flatten().astype('float32'), label
 | |
| 
 | |
| 
 | |
| train_mapper = functools.partial(default_mapper, True)
 | |
| test_mapper = functools.partial(default_mapper, False)
 | |
| 
 | |
| 
 | |
def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper,
                   buffered_size=1024,
                   use_xmap=True,
                   cycle=False):
    '''
    1. read images from tar file and
        merge images into batch files in 102flowers.tgz_batch/
    2. get a reader to read sample from batch file

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                        about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid)
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                    needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images;
                          only used when ``use_xmap`` is True
    :type buffered_size: int
    :param use_xmap: if True the reader is wrapped with ``xmap_readers``
                     (given ``min(4, cpu_count())`` and ``buffered_size``),
                     otherwise with ``map_readers``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: data reader
    :rtype: callable
    '''
    labels = scio.loadmat(label_file)['labels'][0]
    indexes = scio.loadmat(setid_file)[dataset_name][0]
    # Image number ``i`` is looked up at ``labels[i - 1]``.
    img2label = {
        "jpg/image_%05d.jpg" % i: labels[i - 1]
        for i in indexes
    }
    file_list = batch_images_from_tar(data_file, dataset_name, img2label)

    def reader():
        while True:
            with open(file_list, 'r') as f_list:
                for line in f_list:
                    batch_path = line.strip()
                    if not batch_path:
                        # Skip blank lines instead of failing on open('').
                        continue
                    # NOTE(security): pickle must only be used on trusted
                    # data; these batch files are expected to be the ones
                    # produced locally by batch_images_from_tar.
                    with open(batch_path, 'rb') as f:
                        if six.PY2:
                            batch = pickle.load(f)
                        else:
                            batch = pickle.load(f, encoding='bytes')
                    # The batch is fully loaded, so the file is already
                    # closed while samples are handed to the consumer.
                    if six.PY3:
                        batch = cpt.to_text(batch)
                    for sample, label in six.moves.zip(batch['data'],
                                                       batch['label']):
                        # Stored labels are shifted down by one.
                        yield sample, int(label) - 1
            if not cycle:
                break

    if use_xmap:
        return xmap_readers(mapper, reader, min(4, cpu_count()), buffered_size)
    return map_readers(mapper, reader)
 | |
| 
 | |
| 
 | |
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers training set reader.

    It returns a reader over the split named by ``TRAIN_FLAG``. Each
    sample is ``mapper`` applied to an ``(image bytes, label)`` pair,
    where the label is the stored label minus one (so presumably in
    [0, 101] for the 102 categories, not [1, 102]).

    With the default mapper the image is a flattened float32 array
    translated from the original color image by steps (as described for
    ``simple_transform`` -- confirm in paddle.dataset.image):
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples with ``xmap_readers``
                     instead of ``map_readers``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: train data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5),
        TRAIN_FLAG,
        mapper,
        buffered_size,
        use_xmap,
        cycle=cycle)
 | |
| 
 | |
| 
 | |
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers test set reader.

    It returns a reader over the split named by ``TEST_FLAG``. Each
    sample is ``mapper`` applied to an ``(image bytes, label)`` pair,
    where the label is the stored label minus one (so presumably in
    [0, 101] for the 102 categories, not [1, 102]).

    With the default mapper the image is a flattened float32 array
    translated from the original color image by steps:
    1. resize to 256*256
    2. crop to 224*224 (the default mapper passes ``is_train=False``, so
       this is presumably not a random crop -- confirm in
       ``simple_transform``)
    3. flatten

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples with ``xmap_readers``
                     instead of ``map_readers``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset
    :type cycle: bool
    :return: test data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5),
        TEST_FLAG,
        mapper,
        buffered_size,
        use_xmap,
        cycle=cycle)
 | |
| 
 | |
| 
 | |
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
    '''
    Create flowers validation set reader.

    It returns a reader over the split named by ``VALID_FLAG``. Each
    sample is ``mapper`` applied to an ``(image bytes, label)`` pair,
    where the label is the stored label minus one (so presumably in
    [0, 101] for the 102 categories, not [1, 102]).

    With the default mapper the image is a flattened float32 array
    translated from the original color image by steps:
    1. resize to 256*256
    2. crop to 224*224 (the default mapper passes ``is_train=False``, so
       this is presumably not a random crop -- confirm in
       ``simple_transform``)
    3. flatten

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples with ``xmap_readers``
                     instead of ``map_readers``
    :type use_xmap: bool
    :param cycle: whether to cycle through the dataset; defaults to
                  False, matching the behavior before this parameter
                  existed
    :type cycle: bool
    :return: validation data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5),
        VALID_FLAG,
        mapper,
        buffered_size,
        use_xmap,
        cycle=cycle)
 | |
| 
 | |
| 
 | |
def fetch():
    '''
    Download the three flowers dataset files without building a reader.

    Calls ``download`` for the image archive, the label file and the
    set-id file, each under the ``'flowers'`` module name and with its
    MD5 value. Nothing is returned.
    '''
    download(DATA_URL, 'flowers', DATA_MD5)
    download(LABEL_URL, 'flowers', LABEL_MD5)
    download(SETID_URL, 'flowers', SETID_MD5)
 |