From 976a6982baa8a5fb4839d036f7f403ef4b22b0c7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 24 Feb 2017 16:46:14 +0800 Subject: [PATCH 01/12] Add cifar dataset --- python/paddle/v2/data_set/cifar.py | 173 +++++++++++++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 python/paddle/v2/data_set/cifar.py diff --git a/python/paddle/v2/data_set/cifar.py b/python/paddle/v2/data_set/cifar.py new file mode 100644 index 0000000000..54289430d4 --- /dev/null +++ b/python/paddle/v2/data_set/cifar.py @@ -0,0 +1,173 @@ +""" +CIFAR Dataset. + +URL: https://www.cs.toronto.edu/~kriz/cifar.html + +the default train_creator, test_creator used for CIFAR-10 dataset. +""" +from config import DATA_HOME +import os +import hashlib +import urllib2 +import shutil +import tarfile +import cPickle +import itertools +import numpy + +__all__ = ['CIFAR10', 'CIFAR100', 'train_creator', 'test_creator'] + + +def __download_file__(filename, url, md5): + def __file_ok__(): + if not os.path.exists(filename): + return False + md5_hash = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + + return md5_hash.hexdigest() == md5 + + while not __file_ok__(): + response = urllib2.urlopen(url) + with open(filename, mode='wb') as of: + shutil.copyfileobj(fsrc=response, fdst=of) + + +def __read_one_batch__(batch): + data = batch['data'] + labels = batch.get('labels', batch.get('fine_labels', None)) + assert labels is not None + for sample, label in itertools.izip(data, labels): + yield (sample / 255.0).astype(numpy.float32), int(label) + + +CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' +CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' +CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' +CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' + + +class CIFAR(object): + """ + CIFAR dataset reader. The base class for CIFAR-10 and CIFAR-100 + + :param url: Download url. + :param md5: File md5sum + :param meta_filename: Meta file name in package. + :param train_filename: Train file name in package. + :param test_filename: Test file name in package. + """ + + def __init__(self, url, md5, meta_filename, train_filename, test_filename): + filename = os.path.split(url)[-1] + assert DATA_HOME is not None + filepath = os.path.join(DATA_HOME, md5) + if not os.path.exists(filepath): + os.makedirs(filepath) + + self.__full_file__ = os.path.join(filepath, filename) + self.__meta_filename__ = meta_filename + self.__train_filename__ = train_filename + self.__test_filename__ = test_filename + __download_file__(filename=self.__full_file__, url=url, md5=md5) + + def labels(self): + """ + labels get all dataset label in order. + :return: a list of label. 
+ :rtype: list[string] + """ + with tarfile.open(self.__full_file__, mode='r') as f: + name = [ + each_item.name for each_item in f + if self.__meta_filename__ in each_item.name + ][0] + meta_f = f.extractfile(name) + meta = cPickle.load(meta_f) + for key in meta: + if 'label' in key: + return meta[key] + else: + raise RuntimeError("Unexpected branch.") + + def train(self): + """ + Train Reader + """ + return self.__read_batch__(self.__train_filename__) + + def test(self): + """ + Test Reader + """ + return self.__read_batch__(self.__test_filename__) + + def __read_batch__(self, sub_name): + with tarfile.open(self.__full_file__, mode='r') as f: + names = (each_item.name for each_item in f + if sub_name in each_item.name) + + for name in names: + batch = cPickle.load(f.extractfile(name)) + for item in __read_one_batch__(batch): + yield item + + +class CIFAR10(CIFAR): + """ + CIFAR-10 dataset, images are classified in 10 classes. + """ + + def __init__(self): + super(CIFAR10, self).__init__( + CIFAR10_URL, + CIFAR10_MD5, + meta_filename='batches.meta', + train_filename='data_batch', + test_filename='test_batch') + + +class CIFAR100(CIFAR): + """ + CIFAR-100 dataset, images are classified in 100 classes. + """ + + def __init__(self): + super(CIFAR100, self).__init__( + CIFAR100_URL, + CIFAR100_MD5, + meta_filename='meta', + train_filename='train', + test_filename='test') + + +def train_creator(): + """ + Default train reader creator. Use CIFAR-10 dataset. + """ + cifar = CIFAR10() + return cifar.train + + +def test_creator(): + """ + Default test reader creator. Use CIFAR-10 dataset. + """ + cifar = CIFAR10() + return cifar.test + + +def unittest(label_count=100): + cifar = globals()["CIFAR%d" % label_count]() + assert len(cifar.labels()) == label_count + for _ in cifar.test(): + pass + for _ in cifar.train(): + pass + + +if __name__ == '__main__': + unittest(10) + unittest(100) From 434ada47ef0bb039192d4de5d969fc70ae033a0b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Feb 2017 17:05:18 +0800 Subject: [PATCH 02/12] Up to date --- python/paddle/v2/{data_set => dataset}/cifar.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename python/paddle/v2/{data_set => dataset}/cifar.py (100%) diff --git a/python/paddle/v2/data_set/cifar.py b/python/paddle/v2/dataset/cifar.py similarity index 100% rename from python/paddle/v2/data_set/cifar.py rename to python/paddle/v2/dataset/cifar.py From 0bcc4d48defeb00f191c04d868098523965bc0d2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Feb 2017 17:19:29 +0800 Subject: [PATCH 03/12] Simplize cifar --- python/paddle/v2/dataset/cifar.py | 170 ++++++++++-------------------- 1 file changed, 53 insertions(+), 117 deletions(-) diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index 54289430d4..9a999de7e0 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -15,33 +15,10 @@ import cPickle import itertools import numpy -__all__ = ['CIFAR10', 'CIFAR100', 'train_creator', 'test_creator'] - - -def __download_file__(filename, url, md5): - def __file_ok__(): - if not os.path.exists(filename): - return False - md5_hash = hashlib.md5() - with open(filename, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5_hash.update(chunk) - - return md5_hash.hexdigest() == md5 - - while not __file_ok__(): - response = urllib2.urlopen(url) - with open(filename, mode='wb') as of: - shutil.copyfileobj(fsrc=response, fdst=of) - - -def __read_one_batch__(batch): - data = batch['data'] - 
labels = batch.get('labels', batch.get('fine_labels', None)) - assert labels is not None - for sample, label in itertools.izip(data, labels): - yield (sample / 255.0).astype(numpy.float32), int(label) - +__all__ = [ + 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', + 'test_creator' +] CIFAR10_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a' @@ -49,125 +26,84 @@ CIFAR100_URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' -class CIFAR(object): - """ - CIFAR dataset reader. The base class for CIFAR-10 and CIFAR-100 - - :param url: Download url. - :param md5: File md5sum - :param meta_filename: Meta file name in package. - :param train_filename: Train file name in package. - :param test_filename: Test file name in package. - """ +def __read_batch__(filename, sub_name): + def reader(): + def __read_one_batch_impl__(batch): + data = batch['data'] + labels = batch.get('labels', batch.get('fine_labels', None)) + assert labels is not None + for sample, label in itertools.izip(data, labels): + yield (sample / 255.0).astype(numpy.float32), int(label) - def __init__(self, url, md5, meta_filename, train_filename, test_filename): - filename = os.path.split(url)[-1] - assert DATA_HOME is not None - filepath = os.path.join(DATA_HOME, md5) - if not os.path.exists(filepath): - os.makedirs(filepath) - - self.__full_file__ = os.path.join(filepath, filename) - self.__meta_filename__ = meta_filename - self.__train_filename__ = train_filename - self.__test_filename__ = test_filename - __download_file__(filename=self.__full_file__, url=url, md5=md5) - - def labels(self): - """ - labels get all dataset label in order. - :return: a list of label. - :rtype: list[string] - """ - with tarfile.open(self.__full_file__, mode='r') as f: - name = [ - each_item.name for each_item in f - if self.__meta_filename__ in each_item.name - ][0] - meta_f = f.extractfile(name) - meta = cPickle.load(meta_f) - for key in meta: - if 'label' in key: - return meta[key] - else: - raise RuntimeError("Unexpected branch.") - - def train(self): - """ - Train Reader - """ - return self.__read_batch__(self.__train_filename__) - - def test(self): - """ - Test Reader - """ - return self.__read_batch__(self.__test_filename__) - - def __read_batch__(self, sub_name): - with tarfile.open(self.__full_file__, mode='r') as f: + with tarfile.open(filename, mode='r') as f: names = (each_item.name for each_item in f if sub_name in each_item.name) for name in names: batch = cPickle.load(f.extractfile(name)) - for item in __read_one_batch__(batch): + for item in __read_one_batch_impl__(batch): yield item + return reader -class CIFAR10(CIFAR): - """ - CIFAR-10 dataset, images are classified in 10 classes. 
- """ - def __init__(self): - super(CIFAR10, self).__init__( - CIFAR10_URL, - CIFAR10_MD5, - meta_filename='batches.meta', - train_filename='data_batch', - test_filename='test_batch') +def download(url, md5): + filename = os.path.split(url)[-1] + assert DATA_HOME is not None + filepath = os.path.join(DATA_HOME, md5) + if not os.path.exists(filepath): + os.makedirs(filepath) + __full_file__ = os.path.join(filepath, filename) + def __file_ok__(): + if not os.path.exists(__full_file__): + return False + md5_hash = hashlib.md5() + with open(__full_file__, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + + return md5_hash.hexdigest() == md5 + + while not __file_ok__(): + response = urllib2.urlopen(url) + with open(__full_file__, mode='wb') as of: + shutil.copyfileobj(fsrc=response, fdst=of) + return __full_file__ + + +def cifar_100_train_creator(): + fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) + return __read_batch__(fn, 'train') -class CIFAR100(CIFAR): - """ - CIFAR-100 dataset, images are classified in 100 classes. - """ - def __init__(self): - super(CIFAR100, self).__init__( - CIFAR100_URL, - CIFAR100_MD5, - meta_filename='meta', - train_filename='train', - test_filename='test') +def cifar_100_test_creator(): + fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) + return __read_batch__(fn, 'test') def train_creator(): """ Default train reader creator. Use CIFAR-10 dataset. """ - cifar = CIFAR10() - return cifar.train + fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) + return __read_batch__(fn, 'data_batch') def test_creator(): """ Default test reader creator. Use CIFAR-10 dataset. """ - cifar = CIFAR10() - return cifar.test + fn = download(url=CIFAR10_URL, md5=CIFAR10_MD5) + return __read_batch__(fn, 'test_batch') -def unittest(label_count=100): - cifar = globals()["CIFAR%d" % label_count]() - assert len(cifar.labels()) == label_count - for _ in cifar.test(): +def unittest(): + for _ in train_creator()(): pass - for _ in cifar.train(): + for _ in test_creator()(): pass if __name__ == '__main__': - unittest(10) - unittest(100) + unittest() From de9012a50450da19a9227bc28d5bb042ac66fcf7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Feb 2017 20:32:44 +0800 Subject: [PATCH 04/12] Add MovieLens Dataset --- python/paddle/v2/dataset/cifar.py | 35 +------- python/paddle/v2/dataset/config.py | 30 ++++++- python/paddle/v2/dataset/movielens.py | 120 ++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 32 deletions(-) create mode 100644 python/paddle/v2/dataset/movielens.py diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index 9a999de7e0..2ac71c6eff 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -5,16 +5,14 @@ URL: https://www.cs.toronto.edu/~kriz/cifar.html the default train_creator, test_creator used for CIFAR-10 dataset. 
""" -from config import DATA_HOME -import os -import hashlib -import urllib2 -import shutil -import tarfile import cPickle import itertools +import tarfile + import numpy +from config import download + __all__ = [ 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', 'test_creator' @@ -47,31 +45,6 @@ def __read_batch__(filename, sub_name): return reader -def download(url, md5): - filename = os.path.split(url)[-1] - assert DATA_HOME is not None - filepath = os.path.join(DATA_HOME, md5) - if not os.path.exists(filepath): - os.makedirs(filepath) - __full_file__ = os.path.join(filepath, filename) - - def __file_ok__(): - if not os.path.exists(__full_file__): - return False - md5_hash = hashlib.md5() - with open(__full_file__, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5_hash.update(chunk) - - return md5_hash.hexdigest() == md5 - - while not __file_ok__(): - response = urllib2.urlopen(url) - with open(__full_file__, mode='wb') as of: - shutil.copyfileobj(fsrc=response, fdst=of) - return __full_file__ - - def cifar_100_train_creator(): fn = download(url=CIFAR100_URL, md5=CIFAR100_MD5) return __read_batch__(fn, 'train') diff --git a/python/paddle/v2/dataset/config.py b/python/paddle/v2/dataset/config.py index 69e96d65ef..02a009f09c 100644 --- a/python/paddle/v2/dataset/config.py +++ b/python/paddle/v2/dataset/config.py @@ -1,8 +1,36 @@ +import hashlib import os +import shutil +import urllib2 -__all__ = ['DATA_HOME'] +__all__ = ['DATA_HOME', 'download'] DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set') if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) + + +def download(url, md5): + filename = os.path.split(url)[-1] + assert DATA_HOME is not None + filepath = os.path.join(DATA_HOME, md5) + if not os.path.exists(filepath): + os.makedirs(filepath) + __full_file__ = os.path.join(filepath, filename) + + def __file_ok__(): + if not os.path.exists(__full_file__): + return False + md5_hash = hashlib.md5() + with open(__full_file__, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + + return md5_hash.hexdigest() == md5 + + while not __file_ok__(): + response = urllib2.urlopen(url) + with open(__full_file__, mode='wb') as of: + shutil.copyfileobj(fsrc=response, fdst=of) + return __full_file__ diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py new file mode 100644 index 0000000000..314329e91c --- /dev/null +++ b/python/paddle/v2/dataset/movielens.py @@ -0,0 +1,120 @@ +import zipfile +from config import download +import re +import random +import functools + +__all__ = ['train_creator', 'test_creator'] + + +class MovieInfo(object): + def __init__(self, index, categories, title): + self.index = int(index) + self.categories = categories + self.title = title + + def value(self): + return [ + self.index, [CATEGORIES_DICT[c] for c in self.categories], + [MOVIE_TITLE_DICT[w.lower()] for w in self.title.split()] + ] + + +class UserInfo(object): + def __init__(self, index, gender, age, job_id): + self.index = int(index) + self.is_male = gender == 'M' + self.age = [1, 18, 25, 35, 45, 50, 56].index(int(age)) + self.job_id = int(job_id) + + def value(self): + return [self.index, 0 if self.is_male else 1, self.age, self.job_id] + + +MOVIE_INFO = None +MOVIE_TITLE_DICT = None +CATEGORIES_DICT = None +USER_INFO = None + + +def __initialize_meta_info__(): + fn = download( + url='http://files.grouplens.org/datasets/movielens/ml-1m.zip', + md5='c4d9eecfca2ab87c1945afe126590906') + global MOVIE_INFO + 
if MOVIE_INFO is None: + pattern = re.compile(r'^(.*)\((\d+)\)$') + with zipfile.ZipFile(file=fn) as package: + for info in package.infolist(): + assert isinstance(info, zipfile.ZipInfo) + MOVIE_INFO = dict() + title_word_set = set() + categories_set = set() + with package.open('ml-1m/movies.dat') as movie_file: + for i, line in enumerate(movie_file): + movie_id, title, categories = line.strip().split('::') + categories = categories.split('|') + for c in categories: + categories_set.add(c) + title = pattern.match(title).group(1) + MOVIE_INFO[int(movie_id)] = MovieInfo( + index=movie_id, categories=categories, title=title) + for w in title.split(): + title_word_set.add(w.lower()) + + global MOVIE_TITLE_DICT + MOVIE_TITLE_DICT = dict() + for i, w in enumerate(title_word_set): + MOVIE_TITLE_DICT[w] = i + + global CATEGORIES_DICT + CATEGORIES_DICT = dict() + for i, c in enumerate(categories_set): + CATEGORIES_DICT[c] = i + + global USER_INFO + USER_INFO = dict() + with package.open('ml-1m/users.dat') as user_file: + for line in user_file: + uid, gender, age, job, _ = line.strip().split("::") + USER_INFO[int(uid)] = UserInfo( + index=uid, gender=gender, age=age, job_id=job) + return fn + + +def __reader__(rand_seed=0, test_ratio=0.1, is_test=False): + fn = __initialize_meta_info__() + rand = random.Random(x=rand_seed) + with zipfile.ZipFile(file=fn) as package: + with package.open('ml-1m/ratings.dat') as rating: + for line in rating: + if (rand.random() < test_ratio) == is_test: + uid, mov_id, rating, _ = line.strip().split("::") + uid = int(uid) + mov_id = int(mov_id) + rating = float(rating) * 2 - 5.0 + + mov = MOVIE_INFO[mov_id] + usr = USER_INFO[uid] + yield usr.value() + mov.value() + [[rating]] + + +def __reader_creator__(**kwargs): + return lambda: __reader__(**kwargs) + + +train_creator = functools.partial(__reader_creator__, is_test=False) +test_creator = functools.partial(__reader_creator__, is_test=True) + + +def unittest(): + for train_count, _ in enumerate(train_creator()()): + pass + for test_count, _ in enumerate(test_creator()()): + pass + + print train_count, test_count + + +if __name__ == '__main__': + unittest() From 7fe21602f6b767c64e058ab99ad71a1db6bd5d74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=9B=8A?= Date: Mon, 27 Feb 2017 14:43:43 -0800 Subject: [PATCH 05/12] Rename config.py into common.py --- python/paddle/v2/dataset/cifar.py | 2 +- python/paddle/v2/dataset/{config.py => common.py} | 0 python/paddle/v2/dataset/mnist.py | 2 +- python/paddle/v2/dataset/movielens.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename python/paddle/v2/dataset/{config.py => common.py} (100%) diff --git a/python/paddle/v2/dataset/cifar.py b/python/paddle/v2/dataset/cifar.py index 2ac71c6eff..accb32f117 100644 --- a/python/paddle/v2/dataset/cifar.py +++ b/python/paddle/v2/dataset/cifar.py @@ -11,7 +11,7 @@ import tarfile import numpy -from config import download +from common import download __all__ = [ 'cifar_100_train_creator', 'cifar_100_test_creator', 'train_creator', diff --git a/python/paddle/v2/dataset/config.py b/python/paddle/v2/dataset/common.py similarity index 100% rename from python/paddle/v2/dataset/config.py rename to python/paddle/v2/dataset/common.py diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index db84f37aa4..2f195bfb96 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -1,7 +1,7 @@ import sklearn.datasets.mldata import sklearn.model_selection import numpy -from config 
import DATA_HOME +from common import DATA_HOME __all__ = ['train_creator', 'test_creator'] diff --git a/python/paddle/v2/dataset/movielens.py b/python/paddle/v2/dataset/movielens.py index 314329e91c..dcffcff2f5 100644 --- a/python/paddle/v2/dataset/movielens.py +++ b/python/paddle/v2/dataset/movielens.py @@ -1,5 +1,5 @@ import zipfile -from config import download +from common import download import re import random import functools From b93722df95d3907782cdff034df360b79d1fd093 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 27 Feb 2017 23:07:33 +0000 Subject: [PATCH 06/12] Set data cache home directory to ~/.cache/paddle/dataset --- python/paddle/v2/dataset/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 02a009f09c..ae4a5383b0 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -5,7 +5,7 @@ import urllib2 __all__ = ['DATA_HOME', 'download'] -DATA_HOME = os.path.expanduser('~/.cache/paddle_data_set') +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) From 37e2b92089ed583ba9e73f615444c3b080cd1b63 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Mon, 27 Feb 2017 23:41:32 +0000 Subject: [PATCH 07/12] Add md5file into dataset/common.py, and unit test in tests/common_test.py --- python/paddle/v2/dataset/common.py | 13 +++++++++++-- python/paddle/v2/dataset/tests/common_test.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 python/paddle/v2/dataset/tests/common_test.py diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index ae4a5383b0..ff5ed76c0f 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -3,7 +3,7 @@ import os import shutil import urllib2 -__all__ = ['DATA_HOME', 'download'] +__all__ = ['DATA_HOME', 'download', 'md5file'] DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset') @@ -11,7 +11,7 @@ if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) -def download(url, md5): +def download(url, package_name, md5): filename = os.path.split(url)[-1] assert DATA_HOME is not None filepath = os.path.join(DATA_HOME, md5) @@ -34,3 +34,12 @@ def download(url, md5): with open(__full_file__, mode='wb') as of: shutil.copyfileobj(fsrc=response, fdst=of) return __full_file__ + + +def md5file(fname): + hash_md5 = hashlib.md5() + f = open(fname, "rb") + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + f.close() + return hash_md5.hexdigest() diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py new file mode 100644 index 0000000000..d2f97f06de --- /dev/null +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -0,0 +1,16 @@ +import paddle.v2.dataset.common +import unittest +import tempfile + +class TestCommon(unittest.TestCase): + def test_md5file(self): + _, temp_path =tempfile.mkstemp() + f = open(temp_path, 'w') + f.write("Hello\n") + f.close() + self.assertEqual( + '09f7e02f1290be211da707a266f153b3', + paddle.v2.dataset.common.md5file(temp_path)) + +if __name__ == '__main__': + unittest.main() From 91115ab6de016806a5d3ad168b114f54b1eaac87 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 28 Feb 2017 00:17:52 +0000 Subject: [PATCH 08/12] Use module name and raw data filename as the local filename --- python/paddle/v2/dataset/common.py | 42 +++++++------------ 
python/paddle/v2/dataset/tests/common_test.py | 12 ++++-- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index ff5ed76c0f..b1831f38af 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -1,7 +1,7 @@ +import requests import hashlib import os import shutil -import urllib2 __all__ = ['DATA_HOME', 'download', 'md5file'] @@ -11,31 +11,6 @@ if not os.path.exists(DATA_HOME): os.makedirs(DATA_HOME) -def download(url, package_name, md5): - filename = os.path.split(url)[-1] - assert DATA_HOME is not None - filepath = os.path.join(DATA_HOME, md5) - if not os.path.exists(filepath): - os.makedirs(filepath) - __full_file__ = os.path.join(filepath, filename) - - def __file_ok__(): - if not os.path.exists(__full_file__): - return False - md5_hash = hashlib.md5() - with open(__full_file__, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b""): - md5_hash.update(chunk) - - return md5_hash.hexdigest() == md5 - - while not __file_ok__(): - response = urllib2.urlopen(url) - with open(__full_file__, mode='wb') as of: - shutil.copyfileobj(fsrc=response, fdst=of) - return __full_file__ - - def md5file(fname): hash_md5 = hashlib.md5() f = open(fname, "rb") @@ -43,3 +18,18 @@ def md5file(fname): hash_md5.update(chunk) f.close() return hash_md5.hexdigest() + + +def download(url, module_name, md5sum): + dirname = os.path.join(DATA_HOME, module_name) + if not os.path.exists(dirname): + os.makedirs(dirname) + + filename = os.path.join(dirname, url.split('/')[-1]) + if not (os.path.exists(filename) and md5file(filename) == md5sum): + # If file doesn't exist or MD5 doesn't match, then download. + r = requests.get(url, stream=True) + with open(filename, 'w') as f: + shutil.copyfileobj(r.raw, f) + + return filename diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index d2f97f06de..0672a46714 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -5,12 +5,18 @@ import tempfile class TestCommon(unittest.TestCase): def test_md5file(self): _, temp_path =tempfile.mkstemp() - f = open(temp_path, 'w') - f.write("Hello\n") - f.close() + with open(temp_path, 'w') as f: + f.write("Hello\n") self.assertEqual( '09f7e02f1290be211da707a266f153b3', paddle.v2.dataset.common.md5file(temp_path)) + def test_download(self): + yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460' + self.assertEqual( + paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460', + paddle.v2.dataset.common.download( + yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d')) + if __name__ == '__main__': unittest.main() From 792875e3eaa0467e40748f0ed97f022fe7fdcd0b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Feb 2017 09:56:27 +0800 Subject: [PATCH 09/12] Lazy initialize mnist dataset. 
Fix unittest --- python/paddle/v2/dataset/mnist.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index db84f37aa4..faae818a5d 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -16,18 +16,29 @@ def __mnist_reader_creator__(data, target): TEST_SIZE = 10000 +X_train = None +X_test = None +y_train = None +y_test = None -data = sklearn.datasets.mldata.fetch_mldata( - "MNIST original", data_home=DATA_HOME) -X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - data.data, data.target, test_size=TEST_SIZE, random_state=0) + +def __initialize_dataset__(): + global X_train, X_test, y_train, y_test + if X_train is not None: + return + data = sklearn.datasets.mldata.fetch_mldata( + "MNIST original", data_home=DATA_HOME) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + data.data, data.target, test_size=TEST_SIZE, random_state=0) def train_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_train, y_train) def test_creator(): + __initialize_dataset__() return __mnist_reader_creator__(X_test, y_test) From d6c62e852d7788ba81e704323995874b85f89c3e Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 28 Feb 2017 02:26:30 +0000 Subject: [PATCH 10/12] Rewrite mnist.py and add mnist_test.py --- python/paddle/v2/dataset/mnist.py | 74 ++++++++++++++------ python/paddle/v2/dataset/tests/mnist_test.py | 27 +++++++ 2 files changed, 78 insertions(+), 23 deletions(-) create mode 100644 python/paddle/v2/dataset/tests/mnist_test.py diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 2f195bfb96..29fc20eae9 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -1,39 +1,67 @@ -import sklearn.datasets.mldata -import sklearn.model_selection +import paddle.v2.dataset.common +import subprocess import numpy -from common import DATA_HOME -__all__ = ['train_creator', 'test_creator'] +URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' +TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' +TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' -def __mnist_reader_creator__(data, target): - def reader(): - n_samples = data.shape[0] - for i in xrange(n_samples): - yield (data[i] / 255.0).astype(numpy.float32), int(target[i]) +TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' +TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688' + +TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz' +TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' - return reader +TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz' +TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' -TEST_SIZE = 10000 +def reader_creator(image_filename, label_filename, buffer_size): + def reader(): + # According to http://stackoverflow.com/a/38061619/724872, we + # cannot use standard package gzip here. 
+ m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) + m.stdout.read(16) # skip some magic bytes + + l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) + l.stdout.read(8) # skip some magic bytes -data = sklearn.datasets.mldata.fetch_mldata( - "MNIST original", data_home=DATA_HOME) -X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - data.data, data.target, test_size=TEST_SIZE, random_state=0) + while True: + labels = numpy.fromfile( + l.stdout, 'ubyte', count=buffer_size + ).astype("int") + if labels.size != buffer_size: + break # numpy.fromfile returns empty slice after EOF. -def train_creator(): - return __mnist_reader_creator__(X_train, y_train) + images = numpy.fromfile( + m.stdout, 'ubyte', count=buffer_size * 28 * 28 + ).reshape((buffer_size, 28 * 28) + ).astype('float32') + images = images / 255.0 * 2.0 - 1.0 -def test_creator(): - return __mnist_reader_creator__(X_test, y_test) + for i in xrange(buffer_size): + yield images[i, :], labels[i] + m.terminate() + l.terminate() -def unittest(): - assert len(list(test_creator()())) == TEST_SIZE + return reader() +def train(): + return reader_creator( + paddle.v2.dataset.common.download( + TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5), + paddle.v2.dataset.common.download( + TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5), + 100) -if __name__ == '__main__': - unittest() +def test(): + return reader_creator( + paddle.v2.dataset.common.download( + TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5), + paddle.v2.dataset.common.download( + TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5), + 100) diff --git a/python/paddle/v2/dataset/tests/mnist_test.py b/python/paddle/v2/dataset/tests/mnist_test.py new file mode 100644 index 0000000000..23ed2eaba8 --- /dev/null +++ b/python/paddle/v2/dataset/tests/mnist_test.py @@ -0,0 +1,27 @@ +import paddle.v2.dataset.mnist +import unittest + +class TestMNIST(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + for l in reader: + self.assertEqual(l[0].size, 784) + self.assertEqual(l[1].size, 1) + self.assertLess(l[1], 10) + self.assertGreaterEqual(l[1], 0) + sum += 1 + return sum + + def test_train(self): + self.assertEqual( + self.check_reader(paddle.v2.dataset.mnist.train()), + 60000) + + def test_test(self): + self.assertEqual( + self.check_reader(paddle.v2.dataset.mnist.test()), + 10000) + + +if __name__ == '__main__': + unittest.main() From dcbfbb15338e0ca0f195e12ce0e0275995622ca1 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 28 Feb 2017 02:46:31 +0000 Subject: [PATCH 11/12] yapf format --- python/paddle/v2/dataset/mnist.py | 34 +++++++++---------- python/paddle/v2/dataset/tests/common_test.py | 9 ++--- python/paddle/v2/dataset/tests/mnist_test.py | 7 ++-- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 29fc20eae9..ec334d39e6 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -22,23 +22,21 @@ def reader_creator(image_filename, label_filename, buffer_size): # According to http://stackoverflow.com/a/38061619/724872, we # cannot use standard package gzip here. 
m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) - m.stdout.read(16) # skip some magic bytes + m.stdout.read(16) # skip some magic bytes l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) - l.stdout.read(8) # skip some magic bytes + l.stdout.read(8) # skip some magic bytes while True: labels = numpy.fromfile( - l.stdout, 'ubyte', count=buffer_size - ).astype("int") + l.stdout, 'ubyte', count=buffer_size).astype("int") if labels.size != buffer_size: - break # numpy.fromfile returns empty slice after EOF. + break # numpy.fromfile returns empty slice after EOF. images = numpy.fromfile( - m.stdout, 'ubyte', count=buffer_size * 28 * 28 - ).reshape((buffer_size, 28 * 28) - ).astype('float32') + m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape( + (buffer_size, 28 * 28)).astype('float32') images = images / 255.0 * 2.0 - 1.0 @@ -50,18 +48,18 @@ def reader_creator(image_filename, label_filename, buffer_size): return reader() + def train(): return reader_creator( - paddle.v2.dataset.common.download( - TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5), - paddle.v2.dataset.common.download( - TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5), - 100) + paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', + TRAIN_IMAGE_MD5), + paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', + TRAIN_LABEL_MD5), 100) + def test(): return reader_creator( - paddle.v2.dataset.common.download( - TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5), - paddle.v2.dataset.common.download( - TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5), - 100) + paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', + TEST_IMAGE_MD5), + paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', + TEST_LABEL_MD5), 100) diff --git a/python/paddle/v2/dataset/tests/common_test.py b/python/paddle/v2/dataset/tests/common_test.py index 0672a46714..7d8406171b 100644 --- a/python/paddle/v2/dataset/tests/common_test.py +++ b/python/paddle/v2/dataset/tests/common_test.py @@ -2,14 +2,14 @@ import paddle.v2.dataset.common import unittest import tempfile + class TestCommon(unittest.TestCase): def test_md5file(self): - _, temp_path =tempfile.mkstemp() + _, temp_path = tempfile.mkstemp() with open(temp_path, 'w') as f: f.write("Hello\n") - self.assertEqual( - '09f7e02f1290be211da707a266f153b3', - paddle.v2.dataset.common.md5file(temp_path)) + self.assertEqual('09f7e02f1290be211da707a266f153b3', + paddle.v2.dataset.common.md5file(temp_path)) def test_download(self): yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460' @@ -18,5 +18,6 @@ class TestCommon(unittest.TestCase): paddle.v2.dataset.common.download( yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d')) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/dataset/tests/mnist_test.py b/python/paddle/v2/dataset/tests/mnist_test.py index 23ed2eaba8..e4f0b33d52 100644 --- a/python/paddle/v2/dataset/tests/mnist_test.py +++ b/python/paddle/v2/dataset/tests/mnist_test.py @@ -1,6 +1,7 @@ import paddle.v2.dataset.mnist import unittest + class TestMNIST(unittest.TestCase): def check_reader(self, reader): sum = 0 @@ -14,13 +15,11 @@ class TestMNIST(unittest.TestCase): def test_train(self): self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.train()), - 60000) + self.check_reader(paddle.v2.dataset.mnist.train()), 60000) def test_test(self): self.assertEqual( - self.check_reader(paddle.v2.dataset.mnist.test()), - 10000) + self.check_reader(paddle.v2.dataset.mnist.test()), 10000) if __name__ == '__main__': From 
6bc82c8eb8478e2fe0911a8e43ddc7ed3539a372 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Tue, 28 Feb 2017 02:56:48 +0000 Subject: [PATCH 12/12] Add __all__ to mnist.py --- python/paddle/v2/dataset/mnist.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index ec334d39e6..8ba11ca5ec 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -2,17 +2,16 @@ import paddle.v2.dataset.common import subprocess import numpy +__all__ = ['train', 'test'] + URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' - TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688' - TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz' TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' - TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz' TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
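
After [PATCH 12/12] the MNIST module exposes plain readers (mnist.train(), mnist.test()) built on the shared common.download(url, module_name, md5sum) cache helper. A minimal usage sketch, assuming the final tree of this series is importable as paddle.v2.dataset and that zcat is on PATH; the counting loop below is illustrative and not part of the patches:

    import paddle.v2.dataset.common as common
    import paddle.v2.dataset.mnist as mnist

    # download() caches files under ~/.cache/paddle/dataset/<module_name>/
    # and re-downloads only when the cached file's md5 does not match.
    path = common.download(mnist.TRAIN_IMAGE_URL, 'mnist', mnist.TRAIN_IMAGE_MD5)

    # train() yields (image, label) pairs: image is a 784-d float32 vector
    # scaled to [-1, 1], label is an int in [0, 9].
    count = 0
    for image, label in mnist.train():
        count += 1
    print count  # 60000, matching tests/mnist_test.py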
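
The CIFAR and MovieLens modules keep the two-step creator style instead: train_creator() returns a reader function, and calling that function yields samples. A sketch under the same assumptions; the field comments summarize the parsing code in the patches above:

    import paddle.v2.dataset.cifar as cifar
    import paddle.v2.dataset.movielens as movielens

    read_cifar = cifar.train_creator()  # CIFAR-10 by default
    for image, label in read_cifar():
        # image: 3072-d float32 vector in [0, 1]; label: int in [0, 9]
        break

    for record in movielens.train_creator()():
        # record: [user_id, gender_flag, age_bucket, job_id,
        #          movie_id, category_ids, title_word_ids, [scaled_rating]]
        break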