parent 991b582efc
commit f837eee724
@@ -0,0 +1,46 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dataset package.
"""

import mnist
import imikolov
import imdb
import cifar
import movielens
import conll05
import uci_housing
import sentiment
import wmt14
import wmt16
import mq2007
import flowers
import voc2012

__all__ = [
    'mnist',
    'imikolov',
    'imdb',
    'cifar',
    'movielens',
    'conll05',
    'sentiment',
    'uci_housing',
    'wmt14',
    'wmt16',
    'mq2007',
    'flowers',
    'voc2012',
]
@@ -0,0 +1,139 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR dataset.

This module will download dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse train/test set into
paddle reader creators.

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images and 10000 test
images.

The CIFAR-100 dataset is just like the CIFAR-10, except it has 100 classes
containing 600 images each. There are 500 training images and 100 testing
images per class.

"""

import cPickle
import itertools
import numpy
import paddle.v2.dataset.common
import tarfile

__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']

URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


def reader_creator(filename, sub_name):
    def read_batch(batch):
        data = batch['data']
        labels = batch.get('labels', batch.get('fine_labels', None))
        assert labels is not None
        for sample, label in itertools.izip(data, labels):
            yield (sample / 255.0).astype(numpy.float32), int(label)

    def reader():
        with tarfile.open(filename, mode='r') as f:
            names = (each_item.name for each_item in f
                     if sub_name in each_item.name)

            for name in names:
                batch = cPickle.load(f.extractfile(name))
                for item in read_batch(batch):
                    yield item

    return reader


def train100():
    """
    CIFAR-100 training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 99].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
        'train')


def test100():
    """
    CIFAR-100 test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 99].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
        'test')


def train10():
    """
    CIFAR-10 training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 9].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'data_batch')


def test10():
    """
    CIFAR-10 test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 9].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'test_batch')


def fetch():
    paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
    paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    paddle.v2.dataset.common.convert(path, train100(), 1000, "cifar_train100")
    paddle.v2.dataset.common.convert(path, test100(), 1000, "cifar_test100")
    paddle.v2.dataset.common.convert(path, train10(), 1000, "cifar_train10")
    paddle.v2.dataset.common.convert(path, test10(), 1000, "cifar_test10")
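A minimal usage sketch of the readers above, assuming the paddle.v2 package is installed (the first call downloads and caches the CIFAR-10 tarball under ~/.cache/paddle/dataset):

# Every function above returns a reader creator: a zero-argument callable
# that, when called, yields (image, label) tuples.
import paddle.v2.dataset.cifar as cifar

reader = cifar.train10()
for image, label in reader():
    print image.shape, label   # image: (3072,) float32 in [0, 1]; label: 0-9
    break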
@@ -0,0 +1,236 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests
import hashlib
import os
import errno
import shutil
import sys
import importlib
import paddle.v2.dataset
import cPickle
import glob
import cPickle as pickle

__all__ = [
    'DATA_HOME',
    'download',
    'md5file',
    'split',
    'cluster_files_reader',
    'convert',
]

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')


# When running unit tests, there could be multiple processes trying to create
# the DATA_HOME directory simultaneously, so we cannot use an if condition to
# check for the existence of the directory; instead, we use the filesystem as
# the synchronization mechanism by catching the returned errors.
def must_mkdirs(path):
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise


must_mkdirs(DATA_HOME)


def md5file(fname):
    hash_md5 = hashlib.md5()
    f = open(fname, "rb")
    for chunk in iter(lambda: f.read(4096), b""):
        hash_md5.update(chunk)
    f.close()
    return hash_md5.hexdigest()


def download(url, module_name, md5sum, save_name=None):
    dirname = os.path.join(DATA_HOME, module_name)
    if not os.path.exists(dirname):
        os.makedirs(dirname)

    filename = os.path.join(dirname,
                            url.split('/')[-1]
                            if save_name is None else save_name)

    retry = 0
    retry_limit = 3
    while not (os.path.exists(filename) and md5file(filename) == md5sum):
        if os.path.exists(filename):
            print "file md5", md5file(filename), md5sum
        if retry < retry_limit:
            retry += 1
        else:
            raise RuntimeError("Cannot download {0} within retry limit {1}".
                               format(url, retry_limit))
        print "Cache file %s not found, downloading %s" % (filename, url)
        r = requests.get(url, stream=True)
        total_length = r.headers.get('content-length')

        if total_length is None:
            with open(filename, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
        else:
            with open(filename, 'wb') as f:
                dl = 0
                total_length = int(total_length)
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    done = int(50 * dl / total_length)
                    sys.stdout.write("\r[%s%s]" % ('=' * done,
                                                   ' ' * (50 - done)))
                    sys.stdout.flush()

    return filename


def fetch_all():
    for module_name in filter(lambda x: not x.startswith("__"),
                              dir(paddle.v2.dataset)):
        if "fetch" in dir(
                importlib.import_module("paddle.v2.dataset.%s" % module_name)):
            getattr(
                importlib.import_module("paddle.v2.dataset.%s" % module_name),
                "fetch")()


def fetch_all_recordio(path):
    for module_name in filter(lambda x: not x.startswith("__"),
                              dir(paddle.v2.dataset)):
        if "convert" in dir(
                importlib.import_module("paddle.v2.dataset.%s" % module_name)) and \
                not module_name == "common":
            ds_path = os.path.join(path, module_name)
            must_mkdirs(ds_path)
            getattr(
                importlib.import_module("paddle.v2.dataset.%s" % module_name),
                "convert")(ds_path)


def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
    """
    You can call this function as follows:

        split(paddle.v2.dataset.cifar.train10(), line_count=1000,
              suffix="imikolov-train-%05d.pickle")

    The output files will be:

        |-imikolov-train-00000.pickle
        |-imikolov-train-00001.pickle
        |- ...
        |-imikolov-train-00480.pickle

    :param reader: a reader creator
    :param line_count: line count for each file
    :param suffix: the suffix for the output files; it should contain "%d",
        which will be replaced by the id of each file. Default is
        "%05d.pickle".
    :param dumper: a callable that dumps an object to a file; it will be
        called as dumper(obj, f), where obj is the object to be dumped and f
        is a file object. Default is cPickle.dump.
    """
    if not callable(dumper):
        raise TypeError("dumper should be callable.")
    lines = []
    indx_f = 0
    for i, d in enumerate(reader()):
        lines.append(d)
        if i >= line_count and i % line_count == 0:
            with open(suffix % indx_f, "w") as f:
                dumper(lines, f)
                lines = []
                indx_f += 1
    if lines:
        with open(suffix % indx_f, "w") as f:
            dumper(lines, f)


def cluster_files_reader(files_pattern,
                         trainer_count,
                         trainer_id,
                         loader=cPickle.load):
    """
    Create a reader that yields elements from the given files, selecting a
    file subset according to the trainer count and trainer_id.

    :param files_pattern: the files generated by split(...)
    :param trainer_count: total trainer count
    :param trainer_id: the trainer rank id
    :param loader: a callable that loads an object from a file; it will be
        called as loader(f), where f is a file object. Default is
        cPickle.load.
    """

    def reader():
        if not callable(loader):
            raise TypeError("loader should be callable.")
        file_list = glob.glob(files_pattern)
        file_list.sort()
        my_file_list = []
        for idx, fn in enumerate(file_list):
            if idx % trainer_count == trainer_id:
                print "append file: %s" % fn
                my_file_list.append(fn)
        for fn in my_file_list:
            with open(fn, "r") as f:
                lines = loader(f)
                for line in lines:
                    yield line

    return reader


def convert(output_path, reader, line_count, name_prefix):
    """
    Convert data from reader to recordio format files.

    :param output_path: directory in which output files will be saved.
    :param reader: a data reader, from which the convert program will read
        data instances.
    :param line_count: the number of instances written to each output file.
    :param name_prefix: the name prefix of generated files.
    """
    import recordio

    assert line_count >= 1
    indx_f = 0

    def write_data(indx_f, lines):
        filename = "%s/%s-%05d" % (output_path, name_prefix, indx_f)
        writer = recordio.writer(filename)
        for l in lines:
            # FIXME(Yancey1989):
            # dumps with protocol: pickle.HIGHEST_PROTOCOL
            writer.write(cPickle.dumps(l))
        writer.close()

    lines = []
    for i, d in enumerate(reader()):
        lines.append(d)
        if i % line_count == 0 and i >= line_count:
            write_data(indx_f, lines)
            lines = []
            indx_f += 1
            continue

    write_data(indx_f, lines)
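A sketch of how split() and cluster_files_reader() pair up in a multi-trainer job, assuming paddle.v2 is installed (the shard names here are hypothetical):

import paddle.v2.dataset.common as common
import paddle.v2.dataset.uci_housing as uci_housing

# Shard the training data into pickle files of 100 samples each.
common.split(uci_housing.train(), 100, suffix="uci-train-%05d.pickle")

# Trainer 0 of 4 then reads only every 4th shard (files sorted by name).
reader = common.cluster_files_reader(
    "uci-train-*.pickle", trainer_count=4, trainer_id=0)
for sample in reader():
    pass  # samples exactly as yielded by the original reader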
File diff suppressed because it is too large
@@ -0,0 +1,199 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set into paddle reader creators.

This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.

The database was used in:

Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.

"""
import cPickle
import itertools
import functools
from common import download
import tarfile
import scipy.io as scio
from paddle.v2.image import *
from paddle.v2.reader import *
import os
import numpy as np
from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid']

DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In the official 'readme', tstid is the flag of the test data and trnid is
# the flag of the train data. But there is more test data than train data,
# so we exchange the train data and the test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'


def default_mapper(is_train, sample):
    '''
    map image bytes data to type needed by model input layer
    '''
    img, label = sample
    img = load_image_bytes(img)
    img = simple_transform(
        img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
    return img.flatten().astype('float32'), label


train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)


def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper,
                   buffered_size=1024,
                   use_xmap=True):
    '''
    1. read images from tar file and
       merge images into batch files in 102flowers.tgz_batch/
    2. get a reader to read sample from batch file

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                       about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid)
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                   needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: data reader
    :rtype: callable
    '''
    labels = scio.loadmat(label_file)['labels'][0]
    indexes = scio.loadmat(setid_file)[dataset_name][0]
    img2label = {}
    for i in indexes:
        img = "jpg/image_%05d.jpg" % i
        img2label[img] = labels[i - 1]
    file_list = batch_images_from_tar(data_file, dataset_name, img2label)

    def reader():
        for file in open(file_list):
            file = file.strip()
            batch = None
            with open(file, 'r') as f:
                batch = cPickle.load(f)
            data = batch['data']
            for sample, label in itertools.izip(data, batch['label']):
                yield sample, int(label) - 1

    if use_xmap:
        return xmap_readers(mapper, reader, cpu_count(), buffered_size)
    else:
        return map_readers(mapper, reader)


def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers training set reader.
    It returns a reader, each sample in the reader is
    image pixels and a label in [0, 101],
    translated from the original color image by steps:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: train data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
        buffered_size, use_xmap)


def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers test set reader.
    It returns a reader, each sample in the reader is
    image pixels and a label in [0, 101],
    translated from the original color image by steps:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: test data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
        buffered_size, use_xmap)


def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers validation set reader.
    It returns a reader, each sample in the reader is
    image pixels and a label in [0, 101],
    translated from the original color image by steps:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    :param mapper: a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: validation data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
        buffered_size, use_xmap)


def fetch():
    download(DATA_URL, 'flowers', DATA_MD5)
    download(LABEL_URL, 'flowers', LABEL_MD5)
    download(SETID_URL, 'flowers', SETID_MD5)
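A usage sketch, assuming paddle.v2, scipy and an image backend (PIL) are installed; the first call downloads the large image tarball. use_xmap=False keeps the example single-process:

import paddle.v2.dataset.flowers as flowers

reader = flowers.train(use_xmap=False)
for image, label in reader():
    # image: flattened 3x224x224 float32 array; label: int in [0, 101]
    print image.shape, label
    break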
@@ -0,0 +1,148 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMDB dataset.

This module downloads the IMDB dataset from
http://ai.stanford.edu/%7Eamaas/data/sentiment/. This dataset contains a set
of 25,000 highly polar movie reviews for training, and 25,000 for testing.
It also provides an API for building a word dictionary.
"""

import paddle.v2.dataset.common
import collections
import tarfile
import re
import string

__all__ = ['build_dict', 'train', 'test', 'convert']

URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'


def tokenize(pattern):
    """
    Read files that match the given pattern. Tokenize and yield each file.
    """

    with tarfile.open(paddle.v2.dataset.common.download(URL, 'imdb',
                                                        MD5)) as tarf:
        # Note that we should use tarfile.next(), which does
        # sequential access of member files, rather than
        # tarfile.extractfile, which does random access and might
        # destroy hard disks.
        tf = tarf.next()
        while tf is not None:
            if bool(pattern.match(tf.name)):
                # newline and punctuations removal and ad-hoc tokenization.
                yield tarf.extractfile(tf).read().rstrip("\n\r").translate(
                    None, string.punctuation).lower().split()
            tf = tarf.next()


def build_dict(pattern, cutoff):
    """
    Build a word dictionary from the corpus. Keys of the dictionary are words,
    and values are zero-based IDs of these words.
    """
    word_freq = collections.defaultdict(int)
    for doc in tokenize(pattern):
        for word in doc:
            word_freq[word] += 1

    # Not sure if we should prune less-frequent words here.
    word_freq = filter(lambda x: x[1] > cutoff, word_freq.items())

    dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
    words, _ = list(zip(*dictionary))
    word_idx = dict(zip(words, xrange(len(words))))
    word_idx['<unk>'] = len(words)
    return word_idx


def reader_creator(pos_pattern, neg_pattern, word_idx):
    UNK = word_idx['<unk>']
    INS = []

    def load(pattern, out, label):
        for doc in tokenize(pattern):
            out.append(([word_idx.get(w, UNK) for w in doc], label))

    load(pos_pattern, INS, 0)
    load(neg_pattern, INS, 1)

    def reader():
        for doc, label in INS:
            yield doc, label

    return reader


def train(word_idx):
    """
    IMDB training set creator.

    It returns a reader creator, each sample in the reader is a zero-based ID
    sequence and label in [0, 1].

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        re.compile("aclImdb/train/pos/.*\.txt$"),
        re.compile("aclImdb/train/neg/.*\.txt$"), word_idx)


def test(word_idx):
    """
    IMDB test set creator.

    It returns a reader creator, each sample in the reader is a zero-based ID
    sequence and label in [0, 1].

    :param word_idx: word dictionary
    :type word_idx: dict
    :return: Test reader creator
    :rtype: callable
    """
    return reader_creator(
        re.compile("aclImdb/test/pos/.*\.txt$"),
        re.compile("aclImdb/test/neg/.*\.txt$"), word_idx)


def word_dict():
    """
    Build a word dictionary from the corpus.

    :return: Word dictionary
    :rtype: dict
    """
    return build_dict(
        re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)


def fetch():
    paddle.v2.dataset.common.download(URL, 'imdb', MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    w = word_dict()
    paddle.v2.dataset.common.convert(path, train(w), 1000, "imdb_train")
    paddle.v2.dataset.common.convert(path, test(w), 1000, "imdb_test")
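A usage sketch, assuming paddle.v2 is installed (building the dictionary scans the whole tarball, so the first call takes a while):

import paddle.v2.dataset.imdb as imdb

w = imdb.word_dict()   # word -> id, built with a frequency cutoff of 150
for doc, label in imdb.train(w)():
    # doc: list of word ids; label: 0 for pos/ files, 1 for neg/ files,
    # following the load() calls in reader_creator above.
    print len(doc), label
    break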
@@ -0,0 +1,161 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
imikolov's simple dataset.

This module will download dataset from
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators.
"""
import paddle.v2.dataset.common
import collections
import tarfile

__all__ = ['train', 'test', 'build_dict', 'convert']

URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


class DataType(object):
    NGRAM = 1
    SEQ = 2


def word_count(f, word_freq=None):
    if word_freq is None:
        word_freq = collections.defaultdict(int)

    for l in f:
        for w in l.strip().split():
            word_freq[w] += 1
        word_freq['<s>'] += 1
        word_freq['<e>'] += 1

    return word_freq


def build_dict(min_word_freq=50):
    """
    Build a word dictionary from the corpus. Keys of the dictionary are words,
    and values are zero-based IDs of these words.
    """
    train_filename = './simple-examples/data/ptb.train.txt'
    test_filename = './simple-examples/data/ptb.valid.txt'
    with tarfile.open(
            paddle.v2.dataset.common.download(
                paddle.v2.dataset.imikolov.URL, 'imikolov',
                paddle.v2.dataset.imikolov.MD5)) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        word_freq = word_count(testf, word_count(trainf))
        if '<unk>' in word_freq:
            # remove <unk> for now, since we will set it as last index
            del word_freq['<unk>']

        word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())

        word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
        words, _ = list(zip(*word_freq_sorted))
        word_idx = dict(zip(words, xrange(len(words))))
        word_idx['<unk>'] = len(words)

    return word_idx


def reader_creator(filename, word_idx, n, data_type):
    def reader():
        with tarfile.open(
                paddle.v2.dataset.common.download(
                    paddle.v2.dataset.imikolov.URL, 'imikolov',
                    paddle.v2.dataset.imikolov.MD5)) as tf:
            f = tf.extractfile(filename)

            UNK = word_idx['<unk>']
            for l in f:
                if DataType.NGRAM == data_type:
                    assert n > -1, 'Invalid gram length'
                    l = ['<s>'] + l.strip().split() + ['<e>']
                    if len(l) >= n:
                        l = [word_idx.get(w, UNK) for w in l]
                        for i in range(n, len(l) + 1):
                            yield tuple(l[i - n:i])
                elif DataType.SEQ == data_type:
                    l = l.strip().split()
                    l = [word_idx.get(w, UNK) for w in l]
                    src_seq = [word_idx['<s>']] + l
                    trg_seq = l + [word_idx['<e>']]
                    if n > 0 and len(src_seq) > n: continue
                    yield src_seq, trg_seq
                else:
                    assert False, 'Unknown data type'

    return reader


def train(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov training set creator.

    It returns a reader creator, each sample in the reader is a word ID
    tuple.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
                          data_type)


def test(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov test set creator.

    It returns a reader creator, each sample in the reader is a word ID
    tuple.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Test reader creator
    :rtype: callable
    """
    return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
                          data_type)


def fetch():
    paddle.v2.dataset.common.download(URL, "imikolov", MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    N = 5
    word_dict = build_dict()
    paddle.v2.dataset.common.convert(path,
                                     train(word_dict, N), 1000,
                                     "imikolov_train")
    paddle.v2.dataset.common.convert(path,
                                     test(word_dict, N), 1000, "imikolov_test")
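A sketch of the two reading modes selected by DataType above, assuming paddle.v2 is installed (the first call downloads the PTB archive):

import paddle.v2.dataset.imikolov as imikolov

word_dict = imikolov.build_dict()   # word -> id; '<unk>' gets the last id

# NGRAM mode: each sample is a sliding window of n word ids.
for sample in imikolov.train(word_dict, 5)():
    print sample                    # a 5-tuple of word ids
    break

# SEQ mode: each sample is a (source, target) pair shifted by one token:
# '<s>' prepended to the source, '<e>' appended to the target.
seq_reader = imikolov.train(
    word_dict, n=-1, data_type=imikolov.DataType.SEQ)
for src_seq, trg_seq in seq_reader():
    print len(src_seq), len(trg_seq)   # equal lengths
    break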
@@ -0,0 +1,123 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST dataset.

This module will download dataset from http://yann.lecun.com/exdb/mnist/ and
parse training set and test set into paddle reader creators.
"""
import paddle.v2.dataset.common
import subprocess
import numpy
import platform
__all__ = ['train', 'test', 'convert']

URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'


def reader_creator(image_filename, label_filename, buffer_size):
    def reader():
        if platform.system() == 'Darwin':
            zcat_cmd = 'gzcat'
        elif platform.system() == 'Linux':
            zcat_cmd = 'zcat'
        else:
            raise NotImplementedError()

        # According to http://stackoverflow.com/a/38061619/724872, we
        # cannot use standard package gzip here.
        m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
        m.stdout.read(16)  # skip some magic bytes

        l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
        l.stdout.read(8)  # skip some magic bytes

        try:  # the loop below may be exited early by a break.
            while True:
                labels = numpy.fromfile(
                    l.stdout, 'ubyte', count=buffer_size).astype("int")

                if labels.size != buffer_size:
                    break  # numpy.fromfile returns empty slice after EOF.

                images = numpy.fromfile(
                    m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
                        (buffer_size, 28 * 28)).astype('float32')

                images = images / 255.0 * 2.0 - 1.0

                for i in xrange(buffer_size):
                    yield images[i, :], int(labels[i])
        finally:
            m.terminate()
            l.terminate()

    return reader


def train():
    """
    MNIST training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [-1, 1] and label in [0, 9].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
                                          TRAIN_IMAGE_MD5),
        paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
                                          TRAIN_LABEL_MD5), 100)


def test():
    """
    MNIST test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [-1, 1] and label in [0, 9].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist',
                                          TEST_IMAGE_MD5),
        paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
                                          TEST_LABEL_MD5), 100)


def fetch():
    paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
    paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
    paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
    paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    paddle.v2.dataset.common.convert(path, train(), 1000, "mnist_train")
    paddle.v2.dataset.common.convert(path, test(), 1000, "mnist_test")
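A usage sketch, assuming paddle.v2 is installed and zcat/gzcat is on the PATH (the reader above shells out to it):

import paddle.v2.dataset.mnist as mnist

for image, label in mnist.train()():
    print image.shape, label    # image: (784,) float32 scaled to [-1, 1]
    break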
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,141 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script fetches and preprocesses the movie_reviews data set provided by
NLTK.

TODO(yuyang18): Complete dataset.
"""

import collections
from itertools import chain

import nltk
from nltk.corpus import movie_reviews

import paddle.v2.dataset.common

__all__ = ['train', 'test', 'get_word_dict', 'convert']
NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000


def download_data_if_not_yet():
    """
    Download the data set if it has not been downloaded yet.
    """
    try:
        # make sure that nltk can find the data
        if paddle.v2.dataset.common.DATA_HOME not in nltk.data.path:
            nltk.data.path.append(paddle.v2.dataset.common.DATA_HOME)
        movie_reviews.categories()
    except LookupError:
        print "Downloading movie_reviews data set, please wait....."
        nltk.download(
            'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
        print "Download data set success....."
        print "Path is " + nltk.data.find('corpora/movie_reviews').path


def get_word_dict():
    """
    Sort words by their frequency of occurrence in the samples.
    :return:
        words_freq_sorted
    """
    words_freq_sorted = list()
    word_freq_dict = collections.defaultdict(int)
    download_data_if_not_yet()

    for category in movie_reviews.categories():
        for field in movie_reviews.fileids(category):
            for words in movie_reviews.words(field):
                word_freq_dict[words] += 1
    words_sort_list = word_freq_dict.items()
    words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
    for index, word in enumerate(words_sort_list):
        words_freq_sorted.append((word[0], index))
    return words_freq_sorted


def sort_files():
    """
    Sort the sample files so that negative and positive samples can be read
    alternately.
    :return:
        files_list
    """
    files_list = list()
    neg_file_list = movie_reviews.fileids('neg')
    pos_file_list = movie_reviews.fileids('pos')
    files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
    return files_list


def load_sentiment_data():
    """
    Load the data set.
    :return:
        data_set
    """
    data_set = list()
    download_data_if_not_yet()
    words_ids = dict(get_word_dict())
    for sample_file in sort_files():
        words_list = list()
        category = 0 if 'neg' in sample_file else 1
        for word in movie_reviews.words(sample_file):
            words_list.append(words_ids[word.lower()])
        data_set.append((words_list, category))
    return data_set


def reader_creator(data):
    """
    Reader creator: generate an iterator over the given data set.
    :param data:
        train data set or test data set
    """
    for each in data:
        yield each[0], each[1]


def train():
    """
    Default training set reader creator.
    """
    data_set = load_sentiment_data()
    return reader_creator(data_set[0:NUM_TRAINING_INSTANCES])


def test():
    """
    Default test set reader creator.
    """
    data_set = load_sentiment_data()
    return reader_creator(data_set[NUM_TRAINING_INSTANCES:])


def fetch():
    nltk.download(
        'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)


def convert(path):
    """
    Converts dataset to recordio format
    """
    paddle.v2.dataset.common.convert(path, train, 1000, "sentiment_train")
    paddle.v2.dataset.common.convert(path, test, 1000, "sentiment_test")
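Note that, unlike the other modules, reader_creator() here is itself a generator function, so train() and test() return iterables directly instead of zero-argument reader creators. A usage sketch, assuming paddle.v2 and nltk are installed (the first call downloads the NLTK movie_reviews corpus):

import paddle.v2.dataset.sentiment as sentiment

for word_ids, category in sentiment.train():
    print len(word_ids), category    # category: 0 for 'neg', 1 for 'pos'
    break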
@@ -0,0 +1,56 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.cifar
import unittest


class TestCIFAR(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        label = 0
        for l in reader():
            self.assertEqual(l[0].size, 3072)
            if l[1] > label:
                label = l[1]
            sum += 1
        return sum, label

    def test_test10(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.test10())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)

    def test_train10(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.train10())
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 9)

    def test_test100(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.test100())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 99)

    def test_train100(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.cifar.train100())
        self.assertEqual(instances, 50000)
        self.assertEqual(max_label_value, 99)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,94 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.common
import unittest
import tempfile
import glob
import recordio


class TestCommon(unittest.TestCase):
    def test_md5file(self):
        _, temp_path = tempfile.mkstemp()
        with open(temp_path, 'w') as f:
            f.write("Hello\n")
        self.assertEqual('09f7e02f1290be211da707a266f153b3',
                         paddle.v2.dataset.common.md5file(temp_path))

    def test_download(self):
        yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
        self.assertEqual(
            paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
            paddle.v2.dataset.common.download(
                yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))

    def test_split(self):
        def test_reader():
            def reader():
                for x in xrange(10):
                    yield x

            return reader

        temp_path = tempfile.mkdtemp()
        paddle.v2.dataset.common.split(
            test_reader(), 4, suffix=temp_path + '/test-%05d.pickle')
        files = glob.glob(temp_path + '/test-*.pickle')
        self.assertEqual(len(files), 3)

    def test_cluster_file_reader(self):
        temp_path = tempfile.mkdtemp()
        for x in xrange(5):
            with open(temp_path + '/%05d.test' % x, 'w') as f:
                f.write('%d\n' % x)
        reader = paddle.v2.dataset.common.cluster_files_reader(
            temp_path + '/*.test', 5, 0)
        for idx, e in enumerate(reader()):
            self.assertEqual(e, str("0"))

    def test_convert(self):
        record_num = 10
        num_shards = 4

        def test_reader():
            def reader():
                for x in xrange(record_num):
                    yield x

            return reader

        path = tempfile.mkdtemp()
        paddle.v2.dataset.common.convert(path,
                                         test_reader(), num_shards,
                                         'random_images')

        files = glob.glob(path + '/random_images-*')
        self.assertEqual(len(files), num_shards)

        recs = []
        for i in range(0, num_shards):
            n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
            r = recordio.reader(n)
            while True:
                d = r.read()
                if d is None:
                    break
                recs.append(d)

        recs.sort()
        self.assertEqual(len(recs), record_num)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,51 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.flowers
import unittest


class TestFlowers(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        label = 0
        size = 224 * 224 * 3
        for l in reader():
            self.assertEqual(l[0].size, size)
            if l[1] > label:
                label = l[1]
            sum += 1
        return sum, label

    def test_train(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.flowers.train())
        self.assertEqual(instances, 6149)
        self.assertEqual(max_label_value, 101)

    def test_test(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.flowers.test())
        self.assertEqual(instances, 1020)
        self.assertEqual(max_label_value, 101)

    def test_valid(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.flowers.valid())
        self.assertEqual(instances, 1020)
        self.assertEqual(max_label_value, 101)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,57 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.imdb
import unittest
import re

TRAIN_POS_PATTERN = re.compile("aclImdb/train/pos/.*\.txt$")
TRAIN_NEG_PATTERN = re.compile("aclImdb/train/neg/.*\.txt$")
TRAIN_PATTERN = re.compile("aclImdb/train/.*\.txt$")

TEST_POS_PATTERN = re.compile("aclImdb/test/pos/.*\.txt$")
TEST_NEG_PATTERN = re.compile("aclImdb/test/neg/.*\.txt$")
TEST_PATTERN = re.compile("aclImdb/test/.*\.txt$")


class TestIMDB(unittest.TestCase):
    word_idx = None

    def test_build_dict(self):
        if self.word_idx is None:
            self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
                                                              150)

        self.assertEqual(len(self.word_idx), 7036)

    def check_dataset(self, dataset, expected_size):
        if self.word_idx is None:
            self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
                                                              150)

        sum = 0
        for l in dataset(self.word_idx)():
            self.assertEqual(l[1], sum % 2)
            sum += 1
        self.assertEqual(sum, expected_size)

    def test_train(self):
        self.check_dataset(paddle.v2.dataset.imdb.train, 25000)

    def test_test(self):
        self.check_dataset(paddle.v2.dataset.imdb.test, 25000)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,67 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.imikolov
import unittest

WORD_DICT = paddle.v2.dataset.imikolov.build_dict()


class TestMikolov(unittest.TestCase):
    def check_reader(self, reader, n):
        for l in reader():
            self.assertEqual(len(l), n)

    def test_train(self):
        n = 5
        self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)

        first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
            'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
            'rake regatta rubens sim snack-food ssangyong swapo wachter'
        first_line = [
            WORD_DICT.get(w, WORD_DICT['<unk>'])
            for w in first_line.split(' ')
        ]
        for l in paddle.v2.dataset.imikolov.train(
                WORD_DICT, n=-1,
                data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
            # Drop the leading start mark before comparing.
            read_line = l[0][1:]
            break
        self.assertEqual(first_line, read_line)

    def test_test(self):
        n = 5
        self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)

        first_line = 'consumers may want to move their telephones a little '\
            'closer to the tv set'
        first_line = [
            WORD_DICT.get(w, WORD_DICT['<unk>'])
            for w in first_line.split(' ')
        ]
        for l in paddle.v2.dataset.imikolov.test(
                WORD_DICT, n=-1,
                data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
            read_line = l[0][1:]
            break
        self.assertEqual(first_line, read_line)

    def test_total(self):
        # Word indices should run contiguously from 0 to len(WORD_DICT) - 1.
        _, idx = zip(*WORD_DICT.items())
        self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)


if __name__ == '__main__':
    unittest.main()
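As a quick illustration of the reader contract tested above: with the default NGRAM data type each sample is an n-tuple of word indices, while DataType.SEQ yields whole sentences wrapped in start/end marks (hence the l[0][1:] slice). A minimal sketch:

import paddle.v2.dataset.imikolov as imikolov

word_dict = imikolov.build_dict()
for sample in imikolov.train(word_dict, 5)():
    assert len(sample) == 5  # one 5-gram of word indices
    break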
@ -0,0 +1,44 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.mnist
import unittest


class TestMNIST(unittest.TestCase):
    def check_reader(self, reader):
        # Count samples and track the largest label seen.
        count = 0
        label = 0
        for l in reader():
            self.assertEqual(l[0].size, 784)  # flattened 28x28 image
            if l[1] > label:
                label = l[1]
            count += 1
        return count, label

    def test_train(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.mnist.train())
        self.assertEqual(instances, 60000)
        self.assertEqual(max_label_value, 9)

    def test_test(self):
        instances, max_label_value = self.check_reader(
            paddle.v2.dataset.mnist.test())
        self.assertEqual(instances, 10000)
        self.assertEqual(max_label_value, 9)


if __name__ == '__main__':
    unittest.main()
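The same reader can be consumed directly; a minimal sketch (each sample is a 784-float image vector plus a digit label, as asserted above):

import paddle.v2.dataset.mnist as mnist

for image, label in mnist.train()():
    assert image.size == 784 and 0 <= label <= 9
    break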
@ -0,0 +1,33 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.mq2007
import unittest


class TestMQ2007(unittest.TestCase):
    def test_pairwise(self):
        for label, query_left, query_right in paddle.v2.dataset.mq2007.test(
                format="pairwise"):
            self.assertEqual(query_left.shape(), (46, ))
            self.assertEqual(query_right.shape(), (46, ))

    def test_listwise(self):
        for label_array, query_array in paddle.v2.dataset.mq2007.test(
                format="listwise"):
            self.assertEqual(len(label_array), len(query_array))


if __name__ == "__main__":
    unittest.main()
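For context, MQ2007 is a learning-to-rank dataset: in "pairwise" format each sample is (label, query_left, query_right) with 46 ranking features per document, and in "listwise" format it is (label_array, query_array). A minimal sketch mirroring the calls above:

import paddle.v2.dataset.mq2007 as mq2007

for label, left, right in mq2007.test(format="pairwise"):
    # `left` and `right` each carry a 46-dimensional feature vector.
    break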
@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import nltk
import paddle.v2.dataset.sentiment as st
from nltk.corpus import movie_reviews


class TestSentimentMethods(unittest.TestCase):
    def test_get_word_dict(self):
        word_dict = st.get_word_dict()[0:10]
        test_word_list = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3),
                          (u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7),
                          (u'is', 8), (u'in', 9)]
        for idx, each in enumerate(word_dict):
            self.assertEqual(each, test_word_list[idx])
        self.assertTrue("/root/.cache/paddle/dataset" in nltk.data.path)

    def test_sort_files(self):
        # sort_files() should interleave the two polarity classes, so no
        # two consecutive files share a label.
        last_label = ''
        for sample_file in st.sort_files():
            current_label = sample_file.split("/")[0]
            self.assertNotEqual(current_label, last_label)
            last_label = current_label

    def test_data_set(self):
        data_set = st.load_sentiment_data()
        last_label = -1
        for each in st.test():
            self.assertNotEqual(each[1], last_label)
            last_label = each[1]
        self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
        self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
        self.assertEqual(
            len(list(st.test())),
            (st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES))


if __name__ == '__main__':
    unittest.main()
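Usage follows the same shape as the assertions above; a minimal sketch, assuming samples are (word-index list, polarity label) pairs and that the NLTK movie_reviews corpus is fetched on first use:

import paddle.v2.dataset.sentiment as st

for ids, label in st.train():  # labels alternate between the two classes
    break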
@ -0,0 +1,42 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.voc2012
import unittest


class TestVOC(unittest.TestCase):
    def check_reader(self, reader):
        count = 0
        for l in reader():
            # An RGB image is exactly three times the size of its
            # one-channel segmentation label.
            self.assertEqual(l[0].size, 3 * l[1].size)
            count += 1
        return count

    def test_train(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.train())
        self.assertEqual(count, 2913)

    def test_test(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.test())
        self.assertEqual(count, 1464)

    def test_val(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.val())
        self.assertEqual(count, 1449)


if __name__ == '__main__':
    unittest.main()
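A minimal consumption sketch (assuming each sample pairs an RGB image with a single-channel segmentation mask, which is what the 3:1 size assertion above encodes):

import paddle.v2.dataset.voc2012 as voc2012

for image, mask in voc2012.train()():
    assert image.size == 3 * mask.size
    break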
@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.wmt16
import unittest


class TestWMT16(unittest.TestCase):
    def checkout_one_sample(self, sample):
        # A training sample has three fields: source language word
        # indices, target language word indices, and target next-word
        # indices.
        self.assertEqual(len(sample), 3)

        # Check the start mark and end mark in the source word indices.
        self.assertEqual(sample[0][0], 0)
        self.assertEqual(sample[0][-1], 1)

        # Check the start mark in the target word indices.
        self.assertEqual(sample[1][0], 0)

        # Check the end mark in the target next-word indices.
        self.assertEqual(sample[2][-1], 1)

    def test_train(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.train(
                    src_dict_size=100000, trg_dict_size=100000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_test(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.test(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_val(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.validation(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_get_dict(self):
        dict_size = 1000
        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
        self.assertEqual(len(word_dict), dict_size)
        self.assertEqual(word_dict[0], "<s>")
        self.assertEqual(word_dict[1], "<e>")
        self.assertEqual(word_dict[2], "<unk>")


if __name__ == "__main__":
    unittest.main()
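A minimal sketch of the reader contract these tests pin down (index 0 is <s>, index 1 is <e>, index 2 is <unk>, per test_get_dict):

import paddle.v2.dataset.wmt16 as wmt16

for src, trg, trg_next in wmt16.train(src_dict_size=1000,
                                      trg_dict_size=1000)():
    assert src[0] == 0 and src[-1] == 1  # <s> ... <e>
    break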
@ -0,0 +1,134 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
UCI Housing dataset.

This module will download the dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse the training set and test set into paddle reader creators.
"""

import numpy as np
import os
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters

__all__ = ['train', 'test']

URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'
# The 13 input attributes plus the target column (median home value).
feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
    'PTRATIO', 'B', 'LSTAT', 'MEDV'
]

UCI_TRAIN_DATA = None
UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'

def feature_range(maximums, minimums):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    feature_num = len(maximums)
    ax.bar(range(feature_num), maximums - minimums, color='r', align='center')
    ax.set_title('feature scale')
    plt.xticks(range(feature_num), feature_names)
    plt.xlim([-1, feature_num])
    fig.set_figheight(6)
    fig.set_figwidth(10)
    if not os.path.exists('./image'):
        os.makedirs('./image')
    fig.savefig('image/ranges.png', dpi=48)
    plt.close(fig)


def load_data(filename, feature_num=14, ratio=0.8):
    global UCI_TRAIN_DATA, UCI_TEST_DATA
    if UCI_TRAIN_DATA is not None and UCI_TEST_DATA is not None:
        return

    data = np.fromfile(filename, sep=' ')
    data = data.reshape(data.shape[0] // feature_num, feature_num)
    maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
        axis=0) / data.shape[0]
    feature_range(maximums[:-1], minimums[:-1])
    # Rescale each input feature by centering on its mean and dividing by
    # its range; the last column (the price) is left untouched.
    for i in xrange(feature_num - 1):
        data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
    offset = int(data.shape[0] * ratio)
    UCI_TRAIN_DATA = data[:offset]
    UCI_TEST_DATA = data[offset:]
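To make the rescaling step concrete, here is a small numeric sketch with a hypothetical feature column (not taken from the dataset):

import numpy as np

col = np.array([1.0, 2.0, 4.0, 5.0])  # hypothetical feature values
print((col - col.mean()) / (col.max() - col.min()))
# -> [-0.5  -0.25  0.25  0.5 ]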

def train():
    """
    UCI_HOUSING training set creator.

    It returns a reader creator; each sample in the reader is a
    (features, price) pair: the 13 normalized features and the house
    price as a length-1 array.

    :return: Training reader creator
    :rtype: callable
    """
    global UCI_TRAIN_DATA
    load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))

    def reader():
        for d in UCI_TRAIN_DATA:
            yield d[:-1], d[-1:]

    return reader


def test():
    """
    UCI_HOUSING test set creator.

    It returns a reader creator; each sample in the reader is a
    (features, price) pair: the 13 normalized features and the house
    price as a length-1 array.

    :return: Test reader creator
    :rtype: callable
    """
    global UCI_TEST_DATA
    load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))

    def reader():
        for d in UCI_TEST_DATA:
            yield d[:-1], d[-1:]

    return reader


def model():
    tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar',
                                                 MD5_MODEL)
    with open(tar_file, 'r') as f:
        parameters = Parameters.from_tar(f)
    return parameters


def fetch():
    paddle.v2.dataset.common.download(URL, 'uci_housing', MD5)


def convert(path):
    """
    Converts dataset to recordio format.
    """
    paddle.v2.dataset.common.convert(path, train(), 1000, "uci_housing_train")
    paddle.v2.dataset.common.convert(path, test(), 1000, "uci_housing_test")
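Finally, a minimal sketch of consuming the readers defined above (13 normalized features plus a length-1 price array per sample):

import paddle.v2.dataset.uci_housing as uci_housing

for features, price in uci_housing.train()():
    assert len(features) == 13 and len(price) == 1
    break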