parent 991b582efc
commit f837eee724
@@ -0,0 +1,46 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dataset package.
"""

import mnist
import imikolov
import imdb
import cifar
import movielens
import conll05
import uci_housing
import sentiment
import wmt14
import wmt16
import mq2007
import flowers
import voc2012

__all__ = [
    'mnist',
    'imikolov',
    'imdb',
    'cifar',
    'movielens',
    'conll05',
    'sentiment',
    'uci_housing',
    'wmt14',
    'wmt16',
    'mq2007',
    'flowers',
    'voc2012',
]
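A minimal usage sketch for the package declared above (illustrative only; it assumes the package is importable as paddle.v2.dataset, which is how the modules below refer to it):

    import paddle.v2.dataset as dataset

    # Every sub-module listed in __all__ exposes reader creators such as
    # train()/test(); calling one returns a generator factory.
    train_reader = dataset.mnist.train()
    for image, label in train_reader():
        break  # image is a flattened numpy array, label an int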
@@ -0,0 +1,139 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR dataset.

This module will download the dataset from
https://www.cs.toronto.edu/~kriz/cifar.html and parse the train/test sets into
paddle reader creators.

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes,
with 6000 images per class. There are 50000 training images and 10000 test
images.

The CIFAR-100 dataset is just like CIFAR-10, except it has 100 classes
containing 600 images each. There are 500 training images and 100 testing
images per class.

"""

import cPickle
import itertools
import numpy
import paddle.v2.dataset.common
import tarfile

__all__ = ['train100', 'test100', 'train10', 'test10', 'convert']

URL_PREFIX = 'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


def reader_creator(filename, sub_name):
    def read_batch(batch):
        data = batch['data']
        labels = batch.get('labels', batch.get('fine_labels', None))
        assert labels is not None
        for sample, label in itertools.izip(data, labels):
            yield (sample / 255.0).astype(numpy.float32), int(label)

    def reader():
        with tarfile.open(filename, mode='r') as f:
            names = (each_item.name for each_item in f
                     if sub_name in each_item.name)

            for name in names:
                batch = cPickle.load(f.extractfile(name))
                for item in read_batch(batch):
                    yield item

    return reader


def train100():
    """
    CIFAR-100 training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 99].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
        'train')


def test100():
    """
    CIFAR-100 test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 99].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
        'test')


def train10():
    """
    CIFAR-10 training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 9].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'data_batch')


def test10():
    """
    CIFAR-10 test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [0, 1] and label in [0, 9].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
        'test_batch')


def fetch():
    paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
    paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    paddle.v2.dataset.common.convert(path, train100(), 1000, "cifar_train100")
    paddle.v2.dataset.common.convert(path, test100(), 1000, "cifar_test100")
    paddle.v2.dataset.common.convert(path, train10(), 1000, "cifar_train10")
    paddle.v2.dataset.common.convert(path, test10(), 1000, "cifar_test10")
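A short consumption sketch for the CIFAR reader creators above; per read_batch, each sample is a flattened float32 array of 3072 pixel values in [0, 1] plus an integer label:

    import itertools
    import paddle.v2.dataset.cifar as cifar

    reader = cifar.train10()  # generator factory; downloads on first use
    for image, label in itertools.islice(reader(), 3):
        print image.shape, image.dtype, label  # (3072,) float32, 0..9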
@@ -0,0 +1,236 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import requests
|
||||
import hashlib
|
||||
import os
|
||||
import errno
|
||||
import shutil
|
||||
import sys
|
||||
import importlib
|
||||
import paddle.v2.dataset
|
||||
import cPickle
|
||||
import glob
|
||||
import cPickle as pickle
|
||||
|
||||
__all__ = [
|
||||
'DATA_HOME',
|
||||
'download',
|
||||
'md5file',
|
||||
'split',
|
||||
'cluster_files_reader',
|
||||
'convert',
|
||||
]
|
||||
|
||||
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
|
||||
|
||||
|
||||
# When running unit tests, there could be multiple processes that
|
||||
# trying to create DATA_HOME directory simultaneously, so we cannot
|
||||
# use a if condition to check for the existence of the directory;
|
||||
# instead, we use the filesystem as the synchronization mechanism by
|
||||
# catching returned errors.
|
||||
def must_mkdirs(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.EEXIST:
|
||||
raise
|
||||
pass
|
||||
|
||||
|
||||
must_mkdirs(DATA_HOME)
|
||||
|
||||
|
||||
def md5file(fname):
|
||||
hash_md5 = hashlib.md5()
|
||||
f = open(fname, "rb")
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
f.close()
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def download(url, module_name, md5sum, save_name=None):
|
||||
dirname = os.path.join(DATA_HOME, module_name)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
filename = os.path.join(dirname,
|
||||
url.split('/')[-1]
|
||||
if save_name is None else save_name)
|
||||
|
||||
retry = 0
|
||||
retry_limit = 3
|
||||
while not (os.path.exists(filename) and md5file(filename) == md5sum):
|
||||
if os.path.exists(filename):
|
||||
print "file md5", md5file(filename), md5sum
|
||||
if retry < retry_limit:
|
||||
retry += 1
|
||||
else:
|
||||
raise RuntimeError("Cannot download {0} within retry limit {1}".
|
||||
format(url, retry_limit))
|
||||
print "Cache file %s not found, downloading %s" % (filename, url)
|
||||
r = requests.get(url, stream=True)
|
||||
total_length = r.headers.get('content-length')
|
||||
|
||||
if total_length is None:
|
||||
with open(filename, 'wb') as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
else:
|
||||
with open(filename, 'wb') as f:
|
||||
dl = 0
|
||||
total_length = int(total_length)
|
||||
for data in r.iter_content(chunk_size=4096):
|
||||
dl += len(data)
|
||||
f.write(data)
|
||||
done = int(50 * dl / total_length)
|
||||
sys.stdout.write("\r[%s%s]" % ('=' * done,
|
||||
' ' * (50 - done)))
|
||||
sys.stdout.flush()
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def fetch_all():
|
||||
for module_name in filter(lambda x: not x.startswith("__"),
|
||||
dir(paddle.v2.dataset)):
|
||||
if "fetch" in dir(
|
||||
importlib.import_module("paddle.v2.dataset.%s" % module_name)):
|
||||
getattr(
|
||||
importlib.import_module("paddle.v2.dataset.%s" % module_name),
|
||||
"fetch")()
|
||||
|
||||
|
||||
def fetch_all_recordio(path):
|
||||
for module_name in filter(lambda x: not x.startswith("__"),
|
||||
dir(paddle.v2.dataset)):
|
||||
if "convert" in dir(
|
||||
importlib.import_module("paddle.v2.dataset.%s" % module_name)) and \
|
||||
not module_name == "common":
|
||||
ds_path = os.path.join(path, module_name)
|
||||
must_mkdirs(ds_path)
|
||||
getattr(
|
||||
importlib.import_module("paddle.v2.dataset.%s" % module_name),
|
||||
"convert")(ds_path)
|
||||
|
||||
|
||||
def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
|
||||
"""
|
||||
you can call the function as:
|
||||
|
||||
split(paddle.v2.dataset.cifar.train10(), line_count=1000,
|
||||
suffix="imikolov-train-%05d.pickle")
|
||||
|
||||
the output files as:
|
||||
|
||||
|-imikolov-train-00000.pickle
|
||||
|-imikolov-train-00001.pickle
|
||||
|- ...
|
||||
|-imikolov-train-00480.pickle
|
||||
|
||||
:param reader: is a reader creator
|
||||
:param line_count: line count for each file
|
||||
:param suffix: the suffix for the output files, should contain "%d"
|
||||
means the id for each file. Default is "%05d.pickle"
|
||||
:param dumper: is a callable function that dump object to file, this
|
||||
function will be called as dumper(obj, f) and obj is the object
|
||||
will be dumped, f is a file object. Default is cPickle.dump.
|
||||
"""
|
||||
if not callable(dumper):
|
||||
raise TypeError("dumper should be callable.")
|
||||
lines = []
|
||||
indx_f = 0
|
||||
for i, d in enumerate(reader()):
|
||||
lines.append(d)
|
||||
if i >= line_count and i % line_count == 0:
|
||||
with open(suffix % indx_f, "w") as f:
|
||||
dumper(lines, f)
|
||||
lines = []
|
||||
indx_f += 1
|
||||
if lines:
|
||||
with open(suffix % indx_f, "w") as f:
|
||||
dumper(lines, f)
|
||||
|
||||
|
||||
def cluster_files_reader(files_pattern,
|
||||
trainer_count,
|
||||
trainer_id,
|
||||
loader=cPickle.load):
|
||||
"""
|
||||
Create a reader that yield element from the given files, select
|
||||
a file set according trainer count and trainer_id
|
||||
|
||||
:param files_pattern: the files which generating by split(...)
|
||||
:param trainer_count: total trainer count
|
||||
:param trainer_id: the trainer rank id
|
||||
:param loader: is a callable function that load object from file, this
|
||||
function will be called as loader(f) and f is a file object.
|
||||
Default is cPickle.load
|
||||
"""
|
||||
|
||||
def reader():
|
||||
if not callable(loader):
|
||||
raise TypeError("loader should be callable.")
|
||||
file_list = glob.glob(files_pattern)
|
||||
file_list.sort()
|
||||
my_file_list = []
|
||||
for idx, fn in enumerate(file_list):
|
||||
if idx % trainer_count == trainer_id:
|
||||
print "append file: %s" % fn
|
||||
my_file_list.append(fn)
|
||||
for fn in my_file_list:
|
||||
with open(fn, "r") as f:
|
||||
lines = loader(f)
|
||||
for line in lines:
|
||||
yield line
|
||||
|
||||
return reader
|
||||
|
||||
|
||||
def convert(output_path, reader, line_count, name_prefix):
|
||||
import recordio
|
||||
"""
|
||||
Convert data from reader to recordio format files.
|
||||
|
||||
:param output_path: directory in which output files will be saved.
|
||||
:param reader: a data reader, from which the convert program will read
|
||||
data instances.
|
||||
:param line_count: the number of data instances written to each output file.
:param name_prefix: the name prefix of generated files.
|
||||
"""
|
||||
|
||||
assert line_count >= 1
|
||||
indx_f = 0
|
||||
|
||||
def write_data(indx_f, lines):
|
||||
filename = "%s/%s-%05d" % (output_path, name_prefix, indx_f)
|
||||
writer = recordio.writer(filename)
|
||||
for l in lines:
|
||||
# FIXME(Yancey1989):
|
||||
# dumps with protocol: pickle.HIGHEST_PROTOCOL
|
||||
writer.write(cPickle.dumps(l))
|
||||
writer.close()
|
||||
|
||||
lines = []
|
||||
for i, d in enumerate(reader()):
|
||||
lines.append(d)
|
||||
if i % line_count == 0 and i >= line_count:
|
||||
write_data(indx_f, lines)
|
||||
lines = []
|
||||
indx_f += 1
|
||||
continue
|
||||
|
||||
write_data(indx_f, lines)
|
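split() and cluster_files_reader() above are designed to be used as a pair: split() shards a reader into pickle files, and each trainer then reads only the files whose index matches its rank. A sketch, with /tmp/mnist_parts as an illustrative directory that must already exist:

    import paddle.v2.dataset.common as common
    import paddle.v2.dataset.mnist as mnist

    # Shard the training set into pickle files of 1000 samples each.
    common.split(mnist.train(), 1000,
                 suffix="/tmp/mnist_parts/train-%05d.pickle")

    # Trainer 0 of 4 reads shards 0, 4, 8, ... of the sorted file list.
    reader = common.cluster_files_reader(
        "/tmp/mnist_parts/train-*.pickle", trainer_count=4, trainer_id=0)
    for sample in reader():
        pass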
File diff suppressed because it is too large
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This module will download dataset from
|
||||
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
|
||||
and parse train/test set into paddle reader creators.
|
||||
|
||||
This set contains images of flowers belonging to 102 different categories.
|
||||
The images were acquired by searching the web and taking pictures. There are a
|
||||
minimum of 40 images for each category.
|
||||
|
||||
The database was used in:
|
||||
|
||||
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
|
||||
number of classes.Proceedings of the Indian Conference on Computer Vision,
|
||||
Graphics and Image Processing (2008)
|
||||
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
|
||||
|
||||
"""
|
||||
import cPickle
|
||||
import itertools
|
||||
import functools
|
||||
from common import download
|
||||
import tarfile
|
||||
import scipy.io as scio
|
||||
from paddle.v2.image import *
|
||||
from paddle.v2.reader import *
|
||||
import os
|
||||
import numpy as np
|
||||
from multiprocessing import cpu_count
|
||||
__all__ = ['train', 'test', 'valid']
|
||||
|
||||
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
|
||||
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
|
||||
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
|
||||
DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118'
|
||||
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
|
||||
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
|
||||
# In official 'readme', tstid is the flag of test data
|
||||
# and trnid is the flag of train data. But test data is more than train data.
|
||||
# So we exchange the train data and test data.
|
||||
TRAIN_FLAG = 'tstid'
|
||||
TEST_FLAG = 'trnid'
|
||||
VALID_FLAG = 'valid'
|
||||
|
||||
|
||||
def default_mapper(is_train, sample):
|
||||
'''
|
||||
map image bytes data to type needed by model input layer
|
||||
'''
|
||||
img, label = sample
|
||||
img = load_image_bytes(img)
|
||||
img = simple_transform(
|
||||
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
|
||||
return img.flatten().astype('float32'), label
|
||||
|
||||
|
||||
train_mapper = functools.partial(default_mapper, True)
|
||||
test_mapper = functools.partial(default_mapper, False)
|
||||
|
||||
|
||||
def reader_creator(data_file,
|
||||
label_file,
|
||||
setid_file,
|
||||
dataset_name,
|
||||
mapper,
|
||||
buffered_size=1024,
|
||||
use_xmap=True):
|
||||
'''
|
||||
1. read images from tar file and
|
||||
merge images into batch files in 102flowers.tgz_batch/
|
||||
2. get a reader to read sample from batch file
|
||||
|
||||
:param data_file: downloaded data file
|
||||
:type data_file: string
|
||||
:param label_file: downloaded label file
|
||||
:type label_file: string
|
||||
:param setid_file: downloaded setid file containing information
|
||||
about how to split dataset
|
||||
:type setid_file: string
|
||||
:param dataset_name: data set name (tstid|trnid|valid)
|
||||
:type dataset_name: string
|
||||
:param mapper: a function to map image bytes data to type
|
||||
needed by model input layer
|
||||
:type mapper: callable
|
||||
:param buffered_size: the size of buffer used to process images
|
||||
:type buffered_size: int
|
||||
:return: data reader
|
||||
:rtype: callable
|
||||
'''
|
||||
labels = scio.loadmat(label_file)['labels'][0]
|
||||
indexes = scio.loadmat(setid_file)[dataset_name][0]
|
||||
img2label = {}
|
||||
for i in indexes:
|
||||
img = "jpg/image_%05d.jpg" % i
|
||||
img2label[img] = labels[i - 1]
|
||||
file_list = batch_images_from_tar(data_file, dataset_name, img2label)
|
||||
|
||||
def reader():
|
||||
for file in open(file_list):
|
||||
file = file.strip()
|
||||
batch = None
|
||||
with open(file, 'r') as f:
|
||||
batch = cPickle.load(f)
|
||||
data = batch['data']
|
||||
labels = batch['label']
|
||||
for sample, label in itertools.izip(data, batch['label']):
|
||||
yield sample, int(label) - 1
|
||||
|
||||
if use_xmap:
|
||||
return xmap_readers(mapper, reader, cpu_count(), buffered_size)
|
||||
else:
|
||||
return map_readers(mapper, reader)
|
||||
|
||||
|
||||
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
|
||||
'''
|
||||
Create flowers training set reader.
|
||||
It returns a reader, each sample in the reader is
|
||||
image pixels in [0, 1] and label in [1, 102]
|
||||
translated from original color image by steps:
|
||||
1. resize to 256*256
|
||||
2. random crop to 224*224
|
||||
3. flatten
|
||||
:param mapper: a function to map sample.
|
||||
:type mapper: callable
|
||||
:param buffered_size: the size of buffer used to process images
|
||||
:type buffered_size: int
|
||||
:return: train data reader
|
||||
:rtype: callable
|
||||
'''
|
||||
return reader_creator(
|
||||
download(DATA_URL, 'flowers', DATA_MD5),
|
||||
download(LABEL_URL, 'flowers', LABEL_MD5),
|
||||
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
|
||||
buffered_size, use_xmap)
|
||||
|
||||
|
||||
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
|
||||
'''
|
||||
Create flowers test set reader.
|
||||
It returns a reader, each sample in the reader is
|
||||
image pixels in [0, 1] and label in [1, 102]
|
||||
translated from original color image by steps:
|
||||
1. resize to 256*256
|
||||
2. random crop to 224*224
|
||||
3. flatten
|
||||
:param mapper: a function to map sample.
|
||||
:type mapper: callable
|
||||
:param buffered_size: the size of buffer used to process images
|
||||
:type buffered_size: int
|
||||
:return: test data reader
|
||||
:rtype: callable
|
||||
'''
|
||||
return reader_creator(
|
||||
download(DATA_URL, 'flowers', DATA_MD5),
|
||||
download(LABEL_URL, 'flowers', LABEL_MD5),
|
||||
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
|
||||
buffered_size, use_xmap)
|
||||
|
||||
|
||||
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
|
||||
'''
|
||||
Create flowers validation set reader.
|
||||
It returns a reader, each sample in the reader is
|
||||
image pixels in [0, 1] and label in [1, 102]
|
||||
translated from original color image by steps:
|
||||
1. resize to 256*256
|
||||
2. random crop to 224*224
|
||||
3. flatten
|
||||
:param mapper: a function to map sample.
|
||||
:type mapper: callable
|
||||
:param buffered_size: the size of buffer used to process images
|
||||
:type buffered_size: int
|
||||
:return: test data reader
|
||||
:rtype: callable
|
||||
'''
|
||||
return reader_creator(
|
||||
download(DATA_URL, 'flowers', DATA_MD5),
|
||||
download(LABEL_URL, 'flowers', LABEL_MD5),
|
||||
download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
|
||||
buffered_size, use_xmap)
|
||||
|
||||
|
||||
def fetch():
|
||||
download(DATA_URL, 'flowers', DATA_MD5)
|
||||
download(LABEL_URL, 'flowers', LABEL_MD5)
|
||||
download(SETID_URL, 'flowers', SETID_MD5)
|
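A brief usage sketch for the flowers readers above. use_xmap=True (the default) runs the mapper in parallel worker processes via xmap_readers; passing use_xmap=False maps samples in-process, which is slower but easier to debug:

    import paddle.v2.dataset.flowers as flowers

    reader = flowers.train(use_xmap=False)
    for pixels, label in reader():
        # pixels is a flattened float32 crop of 3 * 224 * 224 values
        break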
@@ -0,0 +1,148 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
IMDB dataset.
|
||||
|
||||
This module downloads IMDB dataset from
|
||||
http://ai.stanford.edu/%7Eamaas/data/sentiment/. This dataset contains a set
|
||||
of 25,000 highly polar movie reviews for training, and 25,000 for testing.
|
||||
Besides, this module also provides API for building dictionary.
|
||||
"""
|
||||
|
||||
import paddle.v2.dataset.common
|
||||
import collections
|
||||
import tarfile
|
||||
import re
|
||||
import string
|
||||
|
||||
__all__ = ['build_dict', 'train', 'test', 'convert']
|
||||
|
||||
URL = 'http://ai.stanford.edu/%7Eamaas/data/sentiment/aclImdb_v1.tar.gz'
|
||||
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'
|
||||
|
||||
|
||||
def tokenize(pattern):
|
||||
"""
|
||||
Read files that match the given pattern. Tokenize and yield each file.
|
||||
"""
|
||||
|
||||
with tarfile.open(paddle.v2.dataset.common.download(URL, 'imdb',
|
||||
MD5)) as tarf:
|
||||
# Note that we should use tarfile.next(), which does
|
||||
# sequential access of member files, other than
|
||||
# tarfile.extractfile, which does random access and might
|
||||
# destroy hard disks.
|
||||
tf = tarf.next()
|
||||
while tf != None:
|
||||
if bool(pattern.match(tf.name)):
|
||||
# newline and punctuations removal and ad-hoc tokenization.
|
||||
yield tarf.extractfile(tf).read().rstrip("\n\r").translate(
|
||||
None, string.punctuation).lower().split()
|
||||
tf = tarf.next()
|
||||
|
||||
|
||||
def build_dict(pattern, cutoff):
|
||||
"""
|
||||
Build a word dictionary from the corpus. Keys of the dictionary are words,
|
||||
and values are zero-based IDs of these words.
|
||||
"""
|
||||
word_freq = collections.defaultdict(int)
|
||||
for doc in tokenize(pattern):
|
||||
for word in doc:
|
||||
word_freq[word] += 1
|
||||
|
||||
# Not sure if we should prune less-frequent words here.
|
||||
word_freq = filter(lambda x: x[1] > cutoff, word_freq.items())
|
||||
|
||||
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
|
||||
words, _ = list(zip(*dictionary))
|
||||
word_idx = dict(zip(words, xrange(len(words))))
|
||||
word_idx['<unk>'] = len(words)
|
||||
return word_idx
|
||||
|
||||
|
||||
def reader_creator(pos_pattern, neg_pattern, word_idx):
|
||||
UNK = word_idx['<unk>']
|
||||
INS = []
|
||||
|
||||
def load(pattern, out, label):
|
||||
for doc in tokenize(pattern):
|
||||
out.append(([word_idx.get(w, UNK) for w in doc], label))
|
||||
|
||||
load(pos_pattern, INS, 0)
|
||||
load(neg_pattern, INS, 1)
|
||||
|
||||
def reader():
|
||||
for doc, label in INS:
|
||||
yield doc, label
|
||||
|
||||
return reader
|
||||
|
||||
|
||||
def train(word_idx):
|
||||
"""
|
||||
IMDB training set creator.
|
||||
|
||||
It returns a reader creator, each sample in the reader is a zero-based ID
|
||||
sequence and label in [0, 1].
|
||||
|
||||
:param word_idx: word dictionary
|
||||
:type word_idx: dict
|
||||
:return: Training reader creator
|
||||
:rtype: callable
|
||||
"""
|
||||
return reader_creator(
|
||||
re.compile("aclImdb/train/pos/.*\.txt$"),
|
||||
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx)
|
||||
|
||||
|
||||
def test(word_idx):
|
||||
"""
|
||||
IMDB test set creator.
|
||||
|
||||
It returns a reader creator, each sample in the reader is a zero-based ID
|
||||
sequence and label in [0, 1].
|
||||
|
||||
:param word_idx: word dictionary
|
||||
:type word_idx: dict
|
||||
:return: Test reader creator
|
||||
:rtype: callable
|
||||
"""
|
||||
return reader_creator(
|
||||
re.compile("aclImdb/test/pos/.*\.txt$"),
|
||||
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx)
|
||||
|
||||
|
||||
def word_dict():
|
||||
"""
|
||||
Build a word dictionary from the corpus.
|
||||
|
||||
:return: Word dictionary
|
||||
:rtype: dict
|
||||
"""
|
||||
return build_dict(
|
||||
re.compile("aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$"), 150)
|
||||
|
||||
|
||||
def fetch():
|
||||
paddle.v2.dataset.common.download(URL, 'imdb', MD5)
|
||||
|
||||
|
||||
def convert(path):
|
||||
"""
|
||||
Converts dataset to recordio format
|
||||
"""
|
||||
w = word_dict()
|
||||
paddle.v2.dataset.common.convert(path, lambda: train(w), 1000, "imdb_train")
|
||||
paddle.v2.dataset.common.convert(path, lambda: test(w), 1000, "imdb_test")
|
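The intended flow for the IMDB module above is to build the word dictionary once and pass it to both readers; a minimal sketch:

    import paddle.v2.dataset.imdb as imdb

    word_idx = imdb.word_dict()          # dictionary built with cutoff 150
    train_reader = imdb.train(word_idx)  # generator factory over (ids, label)
    for ids, label in train_reader():
        break  # ids is a list of word IDs, label is 0 or 1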
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
imikolov's simple dataset.
|
||||
|
||||
This module will download dataset from
|
||||
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
|
||||
into paddle reader creators.
|
||||
"""
|
||||
import paddle.v2.dataset.common
|
||||
import collections
|
||||
import tarfile
|
||||
|
||||
__all__ = ['train', 'test', 'build_dict', 'convert']
|
||||
|
||||
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
|
||||
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
|
||||
|
||||
|
||||
class DataType(object):
|
||||
NGRAM = 1
|
||||
SEQ = 2
|
||||
|
||||
|
||||
def word_count(f, word_freq=None):
|
||||
if word_freq is None:
|
||||
word_freq = collections.defaultdict(int)
|
||||
|
||||
for l in f:
|
||||
for w in l.strip().split():
|
||||
word_freq[w] += 1
|
||||
word_freq['<s>'] += 1
|
||||
word_freq['<e>'] += 1
|
||||
|
||||
return word_freq
|
||||
|
||||
|
||||
def build_dict(min_word_freq=50):
|
||||
"""
|
||||
Build a word dictionary from the corpus, Keys of the dictionary are words,
|
||||
and values are zero-based IDs of these words.
|
||||
"""
|
||||
train_filename = './simple-examples/data/ptb.train.txt'
|
||||
test_filename = './simple-examples/data/ptb.valid.txt'
|
||||
with tarfile.open(
|
||||
paddle.v2.dataset.common.download(
|
||||
paddle.v2.dataset.imikolov.URL, 'imikolov',
|
||||
paddle.v2.dataset.imikolov.MD5)) as tf:
|
||||
trainf = tf.extractfile(train_filename)
|
||||
testf = tf.extractfile(test_filename)
|
||||
word_freq = word_count(testf, word_count(trainf))
|
||||
if '<unk>' in word_freq:
|
||||
# remove <unk> for now, since we will set it as last index
|
||||
del word_freq['<unk>']
|
||||
|
||||
word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())
|
||||
|
||||
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
|
||||
words, _ = list(zip(*word_freq_sorted))
|
||||
word_idx = dict(zip(words, xrange(len(words))))
|
||||
word_idx['<unk>'] = len(words)
|
||||
|
||||
return word_idx
|
||||
|
||||
|
||||
def reader_creator(filename, word_idx, n, data_type):
|
||||
def reader():
|
||||
with tarfile.open(
|
||||
paddle.v2.dataset.common.download(
|
||||
paddle.v2.dataset.imikolov.URL, 'imikolov',
|
||||
paddle.v2.dataset.imikolov.MD5)) as tf:
|
||||
f = tf.extractfile(filename)
|
||||
|
||||
UNK = word_idx['<unk>']
|
||||
for l in f:
|
||||
if DataType.NGRAM == data_type:
|
||||
assert n > -1, 'Invalid gram length'
|
||||
l = ['<s>'] + l.strip().split() + ['<e>']
|
||||
if len(l) >= n:
|
||||
l = [word_idx.get(w, UNK) for w in l]
|
||||
for i in range(n, len(l) + 1):
|
||||
yield tuple(l[i - n:i])
|
||||
elif DataType.SEQ == data_type:
|
||||
l = l.strip().split()
|
||||
l = [word_idx.get(w, UNK) for w in l]
|
||||
src_seq = [word_idx['<s>']] + l
|
||||
trg_seq = l + [word_idx['<e>']]
|
||||
if n > 0 and len(src_seq) > n: continue
|
||||
yield src_seq, trg_seq
|
||||
else:
|
||||
assert False, 'Unknown data type'
|
||||
|
||||
return reader
|
||||
|
||||
|
||||
def train(word_idx, n, data_type=DataType.NGRAM):
|
||||
"""
|
||||
imikolov training set creator.
|
||||
|
||||
It returns a reader creator, each sample in the reader is a word ID
|
||||
tuple.
|
||||
|
||||
:param word_idx: word dictionary
|
||||
:type word_idx: dict
|
||||
:param n: sliding window size if type is ngram, otherwise max length of sequence
|
||||
:type n: int
|
||||
:param data_type: data type (ngram or sequence)
|
||||
:type data_type: member variable of DataType (NGRAM or SEQ)
|
||||
:return: Training reader creator
|
||||
:rtype: callable
|
||||
"""
|
||||
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
|
||||
data_type)
|
||||
|
||||
|
||||
def test(word_idx, n, data_type=DataType.NGRAM):
|
||||
"""
|
||||
imikolov test set creator.
|
||||
|
||||
It returns a reader creator, each sample in the reader is a word ID
|
||||
tuple.
|
||||
|
||||
:param word_idx: word dictionary
|
||||
:type word_idx: dict
|
||||
:param n: sliding window size if type is ngram, otherwise max length of sequence
|
||||
:type n: int
|
||||
:param data_type: data type (ngram or sequence)
|
||||
:type data_type: member variable of DataType (NGRAM or SEQ)
|
||||
:return: Test reader creator
|
||||
:rtype: callable
|
||||
"""
|
||||
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
|
||||
data_type)
|
||||
|
||||
|
||||
def fetch():
|
||||
paddle.v2.dataset.common.download(URL, "imikolov", MD5)
|
||||
|
||||
|
||||
def convert(path):
|
||||
"""
|
||||
Converts dataset to recordio format
|
||||
"""
|
||||
N = 5
|
||||
word_dict = build_dict()
|
||||
paddle.v2.dataset.common.convert(path,
|
||||
train(word_dict, N), 1000,
|
||||
"imikolov_train")
|
||||
paddle.v2.dataset.common.convert(path,
|
||||
test(word_dict, N), 1000, "imikolov_test")
|
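The two DataType values above yield differently shaped samples: NGRAM produces fixed-size word-ID windows, while SEQ produces a (source, target) pair of ID sequences shifted by one token. A short sketch:

    import paddle.v2.dataset.imikolov as imikolov

    word_idx = imikolov.build_dict()

    # 5-gram samples: tuples of five word IDs.
    for sample in imikolov.train(word_idx, 5)():
        break

    # Sequence samples: IDs prefixed with <s> vs. suffixed with <e>.
    for src, trg in imikolov.train(
            word_idx, n=-1, data_type=imikolov.DataType.SEQ)():
        break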
@@ -0,0 +1,123 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST dataset.

This module will download the dataset from http://yann.lecun.com/exdb/mnist/ and
parse the training set and test set into paddle reader creators.
"""
import paddle.v2.dataset.common
import subprocess
import numpy
import platform
__all__ = ['train', 'test', 'convert']

URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'


def reader_creator(image_filename, label_filename, buffer_size):
    def reader():
        if platform.system() == 'Darwin':
            zcat_cmd = 'gzcat'
        elif platform.system() == 'Linux':
            zcat_cmd = 'zcat'
        else:
            raise NotImplementedError()

        # According to http://stackoverflow.com/a/38061619/724872, we
        # cannot use the standard gzip package here.
        m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE)
        m.stdout.read(16)  # skip some magic bytes

        l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE)
        l.stdout.read(8)  # skip some magic bytes

        try:  # the caller may break out of the reader early.
            while True:
                labels = numpy.fromfile(
                    l.stdout, 'ubyte', count=buffer_size).astype("int")

                if labels.size != buffer_size:
                    break  # numpy.fromfile returns empty slice after EOF.

                images = numpy.fromfile(
                    m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape(
                        (buffer_size, 28 * 28)).astype('float32')

                images = images / 255.0 * 2.0 - 1.0

                for i in xrange(buffer_size):
                    yield images[i, :], int(labels[i])
        finally:
            m.terminate()
            l.terminate()

    return reader


def train():
    """
    MNIST training set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [-1, 1] and label in [0, 9].

    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
                                          TRAIN_IMAGE_MD5),
        paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
                                          TRAIN_LABEL_MD5), 100)


def test():
    """
    MNIST test set creator.

    It returns a reader creator, each sample in the reader is image pixels in
    [-1, 1] and label in [0, 9].

    :return: Test reader creator.
    :rtype: callable
    """
    return reader_creator(
        paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist',
                                          TEST_IMAGE_MD5),
        paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
                                          TEST_LABEL_MD5), 100)


def fetch():
    paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
    paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
    paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
    paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5)


def convert(path):
    """
    Converts dataset to recordio format
    """
    paddle.v2.dataset.common.convert(path, train(), 1000, "mnist_train")
    paddle.v2.dataset.common.convert(path, test(), 1000, "mnist_test")
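A quick sanity-check sketch for the MNIST readers above. Note that reader_creator rescales pixels to [-1, 1] and shells out to zcat/gzcat, so this only runs where one of those binaries exists:

    import paddle.v2.dataset.mnist as mnist

    for image, label in mnist.train()():
        print image.shape, image.min(), image.max(), label  # (784,), roughly -1.0..1.0
        break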
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,141 @@
|
||||
# /usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script fetches and preprocesses the movie_reviews data set provided by NLTK.
|
||||
|
||||
TODO(yuyang18): Complete dataset.
|
||||
"""
|
||||
|
||||
import collections
|
||||
from itertools import chain
|
||||
|
||||
import nltk
|
||||
from nltk.corpus import movie_reviews
|
||||
|
||||
import paddle.v2.dataset.common
|
||||
|
||||
__all__ = ['train', 'test', 'get_word_dict', 'convert']
|
||||
NUM_TRAINING_INSTANCES = 1600
|
||||
NUM_TOTAL_INSTANCES = 2000
|
||||
|
||||
|
||||
def download_data_if_not_yet():
|
||||
"""
|
||||
Download the data set if it has not been downloaded already.
|
||||
"""
|
||||
try:
|
||||
# make sure that nltk can find the data
|
||||
if paddle.v2.dataset.common.DATA_HOME not in nltk.data.path:
|
||||
nltk.data.path.append(paddle.v2.dataset.common.DATA_HOME)
|
||||
movie_reviews.categories()
|
||||
except LookupError:
|
||||
print "Downloading movie_reviews data set, please wait....."
|
||||
nltk.download(
|
||||
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
|
||||
print "Download data set success....."
|
||||
print "Path is " + nltk.data.find('corpora/movie_reviews').path
|
||||
|
||||
|
||||
def get_word_dict():
|
||||
"""
|
||||
Sort words by the frequency with which they occur in the samples.
|
||||
:return:
|
||||
words_freq_sorted
|
||||
"""
|
||||
words_freq_sorted = list()
|
||||
word_freq_dict = collections.defaultdict(int)
|
||||
download_data_if_not_yet()
|
||||
|
||||
for category in movie_reviews.categories():
|
||||
for field in movie_reviews.fileids(category):
|
||||
for words in movie_reviews.words(field):
|
||||
word_freq_dict[words] += 1
|
||||
words_sort_list = word_freq_dict.items()
|
||||
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
|
||||
for index, word in enumerate(words_sort_list):
|
||||
words_freq_sorted.append((word[0], index))
|
||||
return words_freq_sorted
|
||||
|
||||
|
||||
def sort_files():
|
||||
"""
|
||||
Interleave negative and positive samples so that they can be read alternately.
|
||||
:return:
|
||||
files_list
|
||||
"""
|
||||
files_list = list()
|
||||
neg_file_list = movie_reviews.fileids('neg')
|
||||
pos_file_list = movie_reviews.fileids('pos')
|
||||
files_list = list(chain.from_iterable(zip(neg_file_list, pos_file_list)))
|
||||
return files_list
|
||||
|
||||
|
||||
def load_sentiment_data():
|
||||
"""
|
||||
Load the data set
|
||||
:return:
|
||||
data_set
|
||||
"""
|
||||
data_set = list()
|
||||
download_data_if_not_yet()
|
||||
words_ids = dict(get_word_dict())
|
||||
for sample_file in sort_files():
|
||||
words_list = list()
|
||||
category = 0 if 'neg' in sample_file else 1
|
||||
for word in movie_reviews.words(sample_file):
|
||||
words_list.append(words_ids[word.lower()])
|
||||
data_set.append((words_list, category))
|
||||
return data_set
|
||||
|
||||
|
||||
def reader_creator(data):
|
||||
"""
|
||||
Reader creator, generate an iterator for data set
|
||||
:param data:
|
||||
train data set or test data set
|
||||
"""
|
||||
for each in data:
|
||||
yield each[0], each[1]
|
||||
|
||||
|
||||
def train():
|
||||
"""
|
||||
Default training set reader creator
|
||||
"""
|
||||
data_set = load_sentiment_data()
|
||||
return reader_creator(data_set[0:NUM_TRAINING_INSTANCES])
|
||||
|
||||
|
||||
def test():
|
||||
"""
|
||||
Default test set reader creator
|
||||
"""
|
||||
data_set = load_sentiment_data()
|
||||
return reader_creator(data_set[NUM_TRAINING_INSTANCES:])
|
||||
|
||||
|
||||
def fetch():
|
||||
nltk.download(
|
||||
'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
|
||||
|
||||
|
||||
def convert(path):
|
||||
"""
|
||||
Converts dataset to recordio format
|
||||
"""
|
||||
paddle.v2.dataset.common.convert(path, train, 1000, "sentiment_train")
|
||||
paddle.v2.dataset.common.convert(path, test, 1000, "sentiment_test")
|
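Unlike the other modules, reader_creator here is itself a generator function, so train() and test() return iterables directly rather than reader factories; that is also why convert() above passes the train/test functions without calling them. A short sketch:

    import paddle.v2.dataset.sentiment as sentiment

    for word_ids, polarity in sentiment.train():
        # word_ids index into the NLTK movie_reviews vocabulary;
        # polarity is 0 (negative) or 1 (positive)
        break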
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.cifar
|
||||
import unittest
|
||||
|
||||
|
||||
class TestCIFAR(unittest.TestCase):
|
||||
def check_reader(self, reader):
|
||||
sum = 0
|
||||
label = 0
|
||||
for l in reader():
|
||||
self.assertEqual(l[0].size, 3072)
|
||||
if l[1] > label:
|
||||
label = l[1]
|
||||
sum += 1
|
||||
return sum, label
|
||||
|
||||
def test_test10(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.cifar.test10())
|
||||
self.assertEqual(instances, 10000)
|
||||
self.assertEqual(max_label_value, 9)
|
||||
|
||||
def test_train10(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.cifar.train10())
|
||||
self.assertEqual(instances, 50000)
|
||||
self.assertEqual(max_label_value, 9)
|
||||
|
||||
def test_test100(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.cifar.test100())
|
||||
self.assertEqual(instances, 10000)
|
||||
self.assertEqual(max_label_value, 99)
|
||||
|
||||
def test_train100(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.cifar.train100())
|
||||
self.assertEqual(instances, 50000)
|
||||
self.assertEqual(max_label_value, 99)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@@ -0,0 +1,94 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.common
import unittest
import tempfile
import glob
import recordio


class TestCommon(unittest.TestCase):
    def test_md5file(self):
        _, temp_path = tempfile.mkstemp()
        with open(temp_path, 'w') as f:
            f.write("Hello\n")
        self.assertEqual('09f7e02f1290be211da707a266f153b3',
                         paddle.v2.dataset.common.md5file(temp_path))

    def test_download(self):
        yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
        self.assertEqual(
            paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
            paddle.v2.dataset.common.download(
                yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))

    def test_split(self):
        def test_reader():
            def reader():
                for x in xrange(10):
                    yield x

            return reader

        temp_path = tempfile.mkdtemp()
        paddle.v2.dataset.common.split(
            test_reader(), 4, suffix=temp_path + '/test-%05d.pickle')
        files = glob.glob(temp_path + '/test-*.pickle')
        self.assertEqual(len(files), 3)

    def test_cluster_file_reader(self):
        temp_path = tempfile.mkdtemp()
        for x in xrange(5):
            with open(temp_path + '/%05d.test' % x, 'w') as f:
                f.write('%d\n' % x)
        reader = paddle.v2.dataset.common.cluster_files_reader(
            temp_path + '/*.test', 5, 0)
        for idx, e in enumerate(reader()):
            self.assertEqual(e, str("0"))

    def test_convert(self):
        record_num = 10
        num_shards = 4

        def test_reader():
            def reader():
                for x in xrange(record_num):
                    yield x

            return reader

        path = tempfile.mkdtemp()
        paddle.v2.dataset.common.convert(path,
                                         test_reader(), num_shards,
                                         'random_images')

        files = glob.glob(path + '/random_images-*')
        self.assertEqual(len(files), num_shards)

        recs = []
        for i in range(0, num_shards):
            n = "%s/random_images-%05d-of-%05d" % (path, i, num_shards - 1)
            r = recordio.reader(n)
            while True:
                d = r.read()
                if d is None:
                    break
                recs.append(d)

        recs.sort()
        self.assertEqual(len(recs), record_num)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.flowers
|
||||
import unittest
|
||||
|
||||
|
||||
class TestFlowers(unittest.TestCase):
|
||||
def check_reader(self, reader):
|
||||
sum = 0
|
||||
label = 0
|
||||
size = 224 * 224 * 3
|
||||
for l in reader():
|
||||
self.assertEqual(l[0].size, size)
|
||||
if l[1] > label:
|
||||
label = l[1]
|
||||
sum += 1
|
||||
return sum, label
|
||||
|
||||
def test_train(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.flowers.train())
|
||||
self.assertEqual(instances, 6149)
|
||||
self.assertEqual(max_label_value, 102)
|
||||
|
||||
def test_test(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.flowers.test())
|
||||
self.assertEqual(instances, 1020)
|
||||
self.assertEqual(max_label_value, 102)
|
||||
|
||||
def test_valid(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.flowers.valid())
|
||||
self.assertEqual(instances, 1020)
|
||||
self.assertEqual(max_label_value, 102)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.imdb
|
||||
import unittest
|
||||
import re
|
||||
|
||||
TRAIN_POS_PATTERN = re.compile("aclImdb/train/pos/.*\.txt$")
|
||||
TRAIN_NEG_PATTERN = re.compile("aclImdb/train/neg/.*\.txt$")
|
||||
TRAIN_PATTERN = re.compile("aclImdb/train/.*\.txt$")
|
||||
|
||||
TEST_POS_PATTERN = re.compile("aclImdb/test/pos/.*\.txt$")
|
||||
TEST_NEG_PATTERN = re.compile("aclImdb/test/neg/.*\.txt$")
|
||||
TEST_PATTERN = re.compile("aclImdb/test/.*\.txt$")
|
||||
|
||||
|
||||
class TestIMDB(unittest.TestCase):
|
||||
word_idx = None
|
||||
|
||||
def test_build_dict(self):
|
||||
if self.word_idx == None:
|
||||
self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
|
||||
150)
|
||||
|
||||
self.assertEqual(len(self.word_idx), 7036)
|
||||
|
||||
def check_dataset(self, dataset, expected_size):
|
||||
if self.word_idx == None:
|
||||
self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
|
||||
150)
|
||||
|
||||
sum = 0
|
||||
for l in dataset(self.word_idx):
|
||||
self.assertEqual(l[1], sum % 2)
|
||||
sum += 1
|
||||
self.assertEqual(sum, expected_size)
|
||||
|
||||
def test_train(self):
|
||||
self.check_dataset(paddle.v2.dataset.imdb.train, 25000)
|
||||
|
||||
def test_test(self):
|
||||
self.check_dataset(paddle.v2.dataset.imdb.test, 25000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.imikolov
|
||||
import unittest
|
||||
|
||||
WORD_DICT = paddle.v2.dataset.imikolov.build_dict()
|
||||
|
||||
|
||||
class TestMikolov(unittest.TestCase):
|
||||
def check_reader(self, reader, n):
|
||||
for l in reader():
|
||||
self.assertEqual(len(l), n)
|
||||
|
||||
def test_train(self):
|
||||
n = 5
|
||||
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
|
||||
|
||||
first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
|
||||
'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
|
||||
'rake regatta rubens sim snack-food ssangyong swapo wachter'
|
||||
first_line = [
|
||||
WORD_DICT.get(ch, WORD_DICT['<unk>'])
|
||||
for ch in first_line.split(' ')
|
||||
]
|
||||
for l in paddle.v2.dataset.imikolov.train(
|
||||
WORD_DICT, n=-1,
|
||||
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
|
||||
read_line = l[0][1:]
|
||||
break
|
||||
self.assertEqual(first_line, read_line)
|
||||
|
||||
def test_test(self):
|
||||
n = 5
|
||||
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
|
||||
|
||||
first_line = 'consumers may want to move their telephones a little '\
|
||||
'closer to the tv set'
|
||||
first_line = [
|
||||
WORD_DICT.get(ch, WORD_DICT['<unk>'])
|
||||
for ch in first_line.split(' ')
|
||||
]
|
||||
for l in paddle.v2.dataset.imikolov.test(
|
||||
WORD_DICT, n=-1,
|
||||
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
|
||||
read_line = l[0][1:]
|
||||
break
|
||||
self.assertEqual(first_line, read_line)
|
||||
|
||||
def test_total(self):
|
||||
_, idx = zip(*WORD_DICT.items())
|
||||
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.mnist
|
||||
import unittest
|
||||
|
||||
|
||||
class TestMNIST(unittest.TestCase):
|
||||
def check_reader(self, reader):
|
||||
sum = 0
|
||||
label = 0
|
||||
for l in reader():
|
||||
self.assertEqual(l[0].size, 784)
|
||||
if l[1] > label:
|
||||
label = l[1]
|
||||
sum += 1
|
||||
return sum, label
|
||||
|
||||
def test_train(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.mnist.train())
|
||||
self.assertEqual(instances, 60000)
|
||||
self.assertEqual(max_label_value, 9)
|
||||
|
||||
def test_test(self):
|
||||
instances, max_label_value = self.check_reader(
|
||||
paddle.v2.dataset.mnist.test())
|
||||
self.assertEqual(instances, 10000)
|
||||
self.assertEqual(max_label_value, 9)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.v2.dataset.mq2007
|
||||
import unittest
|
||||
|
||||
|
||||
class TestMQ2007(unittest.TestCase):
|
||||
def test_pairwise(self):
|
||||
for label, query_left, query_right in paddle.v2.dataset.mq2007.test(
|
||||
format="pairwise"):
|
||||
self.assertEqual(query_left.shape(), (46, ))
|
||||
self.assertEqual(query_right.shape(), (46, ))
|
||||
|
||||
def test_listwise(self):
|
||||
for label_array, query_array in paddle.v2.dataset.mq2007.test(
|
||||
format="listwise"):
|
||||
self.assertEqual(len(label_array), len(query_array))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@@ -0,0 +1,55 @@
|
||||
# /usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import nltk
|
||||
import paddle.v2.dataset.sentiment as st
|
||||
from nltk.corpus import movie_reviews
|
||||
|
||||
|
||||
class TestSentimentMethods(unittest.TestCase):
|
||||
def test_get_word_dict(self):
|
||||
word_dict = st.get_word_dict()[0:10]
|
||||
test_word_list = [(u',', 0), (u'the', 1), (u'.', 2), (u'a', 3),
|
||||
(u'and', 4), (u'of', 5), (u'to', 6), (u"'", 7),
|
||||
(u'is', 8), (u'in', 9)]
|
||||
for idx, each in enumerate(word_dict):
|
||||
self.assertEqual(each, test_word_list[idx])
|
||||
self.assertTrue("/root/.cache/paddle/dataset" in nltk.data.path)
|
||||
|
||||
def test_sort_files(self):
|
||||
last_label = ''
|
||||
for sample_file in st.sort_files():
|
||||
current_label = sample_file.split("/")[0]
|
||||
self.assertNotEqual(current_label, last_label)
|
||||
last_label = current_label
|
||||
|
||||
def test_data_set(self):
|
||||
data_set = st.load_sentiment_data()
|
||||
last_label = -1
|
||||
for each in st.test():
|
||||
self.assertNotEqual(each[1], last_label)
|
||||
last_label = each[1]
|
||||
self.assertEqual(len(data_set), st.NUM_TOTAL_INSTANCES)
|
||||
self.assertEqual(len(list(st.train())), st.NUM_TRAINING_INSTANCES)
|
||||
self.assertEqual(
|
||||
len(list(st.test())),
|
||||
(st.NUM_TOTAL_INSTANCES - st.NUM_TRAINING_INSTANCES))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,42 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.voc2012
import unittest


class TestVOC(unittest.TestCase):
    def check_reader(self, reader):
        sum = 0
        label = 0
        for l in reader():
            self.assertEqual(l[0].size, 3 * l[1].size)
            sum += 1
        return sum

    def test_train(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.train())
        self.assertEqual(count, 2913)

    def test_test(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.test())
        self.assertEqual(count, 1464)

    def test_val(self):
        count = self.check_reader(paddle.v2.dataset.voc2012.val())
        self.assertEqual(count, 1449)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.wmt16
import unittest


class TestWMT16(unittest.TestCase):
    def checkout_one_sample(self, sample):
        # train data has 3 fields: source language word indices,
        # target language word indices, and target next word indices.
        self.assertEqual(len(sample), 3)

        # test start mark and end mark in source word indices.
        self.assertEqual(sample[0][0], 0)
        self.assertEqual(sample[0][-1], 1)

        # test start mark in target word indices
        self.assertEqual(sample[1][0], 0)

        # test end mark in target next word indices
        self.assertEqual(sample[2][-1], 1)
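
    # Illustrative sketch only (not produced by the dataset itself): with
    # the start mark <s> mapped to 0 and the end mark <e> mapped to 1, a
    # sample is expected to look roughly like
    #     ([0, 23, 7, ..., 1],   # source word indices
    #      [0, 41, 5, ...],      # target word indices
    #      [41, 5, ..., 1])      # target next-word indices
    # i.e. the source is wrapped in start/end marks, the target begins with
    # the start mark, and the shifted target ends with the end mark.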

    def test_train(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.train(
                    src_dict_size=100000, trg_dict_size=100000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_test(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.test(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_val(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.validation(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_get_dict(self):
        dict_size = 1000
        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
        self.assertEqual(len(word_dict), dict_size)
        self.assertEqual(word_dict[0], "<s>")
        self.assertEqual(word_dict[1], "<e>")
        self.assertEqual(word_dict[2], "<unk>")


if __name__ == "__main__":
    unittest.main()
@ -0,0 +1,134 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
UCI Housing dataset.

This module will download the dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse the training set and test set into paddle reader creators.
"""

import numpy as np
import os
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters

__all__ = ['train', 'test', 'convert']

URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'
# Names of the 13 input features; the 14th column of the raw data is the
# house price (the regression target).
feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
    'PTRATIO', 'B', 'LSTAT'
]

UCI_TRAIN_DATA = None
UCI_TEST_DATA = None
URL_MODEL = 'https://github.com/PaddlePaddle/book/raw/develop/01.fit_a_line/fit_a_line.tar'
MD5_MODEL = '52fc3da8ef3937822fcdd87ee05c0c9b'

def feature_range(maximums, minimums):
    # Save a bar chart of the per-feature value ranges to image/ranges.png.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    feature_num = len(maximums)
    ax.bar(range(feature_num), maximums - minimums, color='r', align='center')
    ax.set_title('feature scale')
    plt.xticks(range(feature_num), feature_names)
    plt.xlim([-1, feature_num])
    fig.set_figheight(6)
    fig.set_figwidth(10)
    if not os.path.exists('./image'):
        os.makedirs('./image')
    fig.savefig('image/ranges.png', dpi=48)
    plt.close(fig)


def load_data(filename, feature_num=14, ratio=0.8):
    global UCI_TRAIN_DATA, UCI_TEST_DATA
    if UCI_TRAIN_DATA is not None and UCI_TEST_DATA is not None:
        return

    data = np.fromfile(filename, sep=' ')
    data = data.reshape(data.shape[0] / feature_num, feature_num)
    maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
        axis=0) / data.shape[0]
    feature_range(maximums[:-1], minimums[:-1])
    # Normalize each feature column to (x - mean) / (max - min); the last
    # column (the price) is left untouched.
    for i in xrange(feature_num - 1):
        data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
    # Split the samples into training and test sets by `ratio`.
    offset = int(data.shape[0] * ratio)
    UCI_TRAIN_DATA = data[:offset]
    UCI_TEST_DATA = data[offset:]


def train():
    """
    UCI_HOUSING training set creator.

    It returns a reader creator. Each sample yielded by the reader is a pair
    of the normalized feature vector and the house price.

    :return: Training reader creator
    :rtype: callable
    """
    global UCI_TRAIN_DATA
    load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))

    def reader():
        for d in UCI_TRAIN_DATA:
            yield d[:-1], d[-1:]

    return reader
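
# Illustrative usage sketch (not part of this module): a reader creator is
# typically consumed by calling it to obtain an iterable of samples, e.g.
#
#     reader = train()
#     for features, price in reader():
#         # `features` is a normalized 13-element vector, `price` a
#         # 1-element array holding the target value.
#         break
#
# The same pattern applies to test() below.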


def test():
    """
    UCI_HOUSING test set creator.

    It returns a reader creator. Each sample yielded by the reader is a pair
    of the normalized feature vector and the house price.

    :return: Test reader creator
    :rtype: callable
    """
    global UCI_TEST_DATA
    load_data(paddle.v2.dataset.common.download(URL, 'uci_housing', MD5))

    def reader():
        for d in UCI_TEST_DATA:
            yield d[:-1], d[-1:]

    return reader


def model():
    # Download the pre-trained fit_a_line model and load its parameters.
    tar_file = paddle.v2.dataset.common.download(URL_MODEL, 'fit_a_line.tar',
                                                 MD5_MODEL)
    with open(tar_file, 'r') as f:
        parameters = Parameters.from_tar(f)
    return parameters


def fetch():
    paddle.v2.dataset.common.download(URL, 'uci_housing', MD5)


def convert(path):
    """
    Converts dataset to recordio format.
    """
    paddle.v2.dataset.common.convert(path, train(), 1000, "uci_housing_train")
    paddle.v2.dataset.common.convert(path, test(), 1000, "uci_housing_test")
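
# Illustrative usage sketch (an assumption, not part of this module):
# convert() would typically be pointed at an output directory, e.g.
#
#     convert('/tmp/recordio_uci_housing')
#
# which writes the train and test sets as recordio files under that path via
# paddle.v2.dataset.common.convert, with 1000 samples per output file.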
Some files were not shown because too many files have changed in this diff.