parent 644dfd7dd5
commit 1a72a9035b
@@ -0,0 +1,207 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import tarfile
import numpy as np
import six
from six.moves import cPickle as pickle

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Cifar10', 'Cifar100']

URL_PREFIX = 'https://dataset.bj.bcebos.com/cifar/'
CIFAR10_URL = URL_PREFIX + 'cifar-10-python.tar.gz'
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'

MODE_FLAG_MAP = {
    'train10': 'data_batch',
    'test10': 'test_batch',
    'train100': 'train',
    'test100': 'test'
}
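
# A quick illustration of how the map above is used: Cifar10(mode='train')
# looks up MODE_FLAG_MAP['train10'] and therefore loads every archive member
# whose name contains 'data_batch', while Cifar100(mode='test') resolves to
# MODE_FLAG_MAP['test100'] and loads members containing 'test'.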


class Cifar10(Dataset):
    """
    Implementation of `Cifar-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 10 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of cifar-10 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Cifar10
            from paddle.incubate.hapi.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = paddle.nn.Linear(3072, 10, act='softmax')

                def forward(self, image, label):
                    image = paddle.reshape(image, (1, -1))
                    return self.fc(image), label

            paddle.disable_static()

            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
            cifar10 = Cifar10(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar10[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self._init_url_md5_flag()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, self.data_url, self.data_md5, 'cifar', download)

        self.transform = transform

        # read dataset into memory
        self._load_data()

    def _init_url_md5_flag(self):
        self.data_url = CIFAR10_URL
        self.data_md5 = CIFAR10_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '10']

    def _load_data(self):
        self.data = []
        with tarfile.open(self.data_file, mode='r') as f:
            names = (each_item.name for each_item in f
                     if self.flag in each_item.name)

            for name in names:
                if six.PY2:
                    batch = pickle.load(f.extractfile(name))
                else:
                    batch = pickle.load(f.extractfile(name), encoding='bytes')

                data = batch[six.b('data')]
                labels = batch.get(
                    six.b('labels'), batch.get(six.b('fine_labels'), None))
                assert labels is not None
                for sample, label in six.moves.zip(data, labels):
                    self.data.append((sample, label))

    def __getitem__(self, idx):
        image, label = self.data[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.data)


class Cifar100(Cifar10):
    """
    Implementation of `Cifar-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_
    dataset, which has 100 categories.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of cifar-100 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Cifar100
            from paddle.incubate.hapi.vision.transforms import Normalize

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()
                    self.fc = paddle.nn.Linear(3072, 100, act='softmax')

                def forward(self, image, label):
                    image = paddle.reshape(image, (1, -1))
                    return self.fc(image), label

            paddle.disable_static()

            normalize = Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
            cifar100 = Cifar100(mode='train', transform=normalize)

            for i in range(10):
                image, label = cifar100[i]
                image = paddle.to_tensor(image)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        super(Cifar100, self).__init__(data_file, mode, transform, download)

    def _init_url_md5_flag(self):
        self.data_url = CIFAR100_URL
        self.data_md5 = CIFAR100_MD5
        self.flag = MODE_FLAG_MAP[self.mode + '100']
File diff suppressed because it is too large
@@ -0,0 +1,144 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import re
import six
import string
import tarfile
import numpy as np
import collections

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Imdb']

URL = 'https://dataset.bj.bcebos.com/imdb%2FaclImdb_v1.tar.gz'
MD5 = '7c2ac02c03563afcf9b574c7e56c153a'


class Imdb(Dataset):
    """
    Implementation of `IMDB <https://www.imdb.com/interfaces/>`_ dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        cutoff(int): cutoff number for building word dictionary. Default 150.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of IMDB dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Imdb

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, doc, label):
                    return paddle.sum(doc), label

            paddle.disable_static()

            imdb = Imdb(mode='train')

            for i in range(10):
                doc, label = imdb[i]
                doc = paddle.to_tensor(doc)
                label = paddle.to_tensor(label)

                model = SimpleNet()
                doc, label = model(doc, label)
                print(doc.numpy().shape, label.numpy().shape)

    """

    def __init__(self, data_file=None, mode='train', cutoff=150, download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'imdb', download)

        # Build a word dictionary from the corpus
        self.word_idx = self._build_word_dict(cutoff)

        # read dataset into memory
        self._load_anno()

    def _build_word_dict(self, cutoff):
        word_freq = collections.defaultdict(int)
        pattern = re.compile(r"aclImdb/((train)|(test))/((pos)|(neg))/.*\.txt$")
        for doc in self._tokenize(pattern):
            for word in doc:
                word_freq[word] += 1

        # Not sure if we should prune less-frequent words here.
        word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]

        dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
        words, _ = list(zip(*dictionary))
        word_idx = dict(list(zip(words, six.moves.range(len(words)))))
        word_idx['<unk>'] = len(words)
        return word_idx
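
    # The mapping above ranks words by descending frequency (ties broken
    # alphabetically), so the most frequent surviving word gets index 0 and
    # '<unk>' always gets the last index, matching the UNK fallback used in
    # _load_anno below.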

    def _tokenize(self, pattern):
        data = []
        with tarfile.open(self.data_file) as tarf:
            tf = tarf.next()
            while tf is not None:
                if bool(pattern.match(tf.name)):
                    # newline and punctuations removal and ad-hoc tokenization.
                    data.append(
                        tarf.extractfile(tf).read().rstrip(six.b("\n\r"))
                        .translate(None, six.b(string.punctuation)).lower(
                        ).split())
                tf = tarf.next()

        return data

    def _load_anno(self):
        pos_pattern = re.compile(r"aclImdb/{}/pos/.*\.txt$".format(self.mode))
        neg_pattern = re.compile(r"aclImdb/{}/neg/.*\.txt$".format(self.mode))

        UNK = self.word_idx['<unk>']

        self.docs = []
        self.labels = []
        for doc in self._tokenize(pos_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(0)
        for doc in self._tokenize(neg_pattern):
            self.docs.append([self.word_idx.get(w, UNK) for w in doc])
            self.labels.append(1)

    def __getitem__(self, idx):
        return (np.array(self.docs[idx]), np.array([self.labels[idx]]))

    def __len__(self):
        return len(self.docs)
@@ -0,0 +1,171 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import tarfile
import numpy as np
import collections

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ['Imikolov']

URL = 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


class Imikolov(Dataset):
    """
    Implementation of imikolov dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        data_type(str): 'NGRAM' or 'SEQ'. Default 'NGRAM'.
        window_size(int): sliding window size for 'NGRAM' data. Default -1.
        mode(str): 'train' or 'test' mode. Default 'train'.
        min_word_freq(int): minimum word frequency for building word dictionary. Default 50.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of imikolov dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Imikolov

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, src, trg):
                    return paddle.sum(src), paddle.sum(trg)

            paddle.disable_static()

            imikolov = Imikolov(mode='train', data_type='SEQ', window_size=2)

            for i in range(10):
                src, trg = imikolov[i]
                src = paddle.to_tensor(src)
                trg = paddle.to_tensor(trg)

                model = SimpleNet()
                src, trg = model(src, trg)
                print(src.numpy().shape, trg.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 data_type='NGRAM',
                 window_size=-1,
                 mode='train',
                 min_word_freq=50,
                 download=True):
        assert data_type.upper() in ['NGRAM', 'SEQ'], \
            "data type should be 'NGRAM' or 'SEQ', but got {}".format(data_type)
        self.data_type = data_type.upper()

        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.window_size = window_size
        self.min_word_freq = min_word_freq

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'imikolov', download)

        # Build a word dictionary from the corpus
        self.word_idx = self._build_word_dict(min_word_freq)

        # read dataset into memory
        self._load_anno()

    def word_count(self, f, word_freq=None):
        if word_freq is None:
            word_freq = collections.defaultdict(int)

        for l in f:
            for w in l.strip().split():
                word_freq[w] += 1
            word_freq['<s>'] += 1
            word_freq['<e>'] += 1

        return word_freq

    def _build_word_dict(self, cutoff):
        train_filename = './simple-examples/data/ptb.train.txt'
        test_filename = './simple-examples/data/ptb.valid.txt'
        with tarfile.open(self.data_file) as tf:
            trainf = tf.extractfile(train_filename)
            testf = tf.extractfile(test_filename)
            word_freq = self.word_count(testf, self.word_count(trainf))
            if '<unk>' in word_freq:
                # remove <unk> for now, since we will set it as last index
                del word_freq['<unk>']

            word_freq = [
                x for x in six.iteritems(word_freq) if x[1] > self.min_word_freq
            ]

            word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
            words, _ = list(zip(*word_freq_sorted))
            word_idx = dict(list(zip(words, six.moves.range(len(words)))))
            word_idx['<unk>'] = len(words)

        return word_idx
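
    # Layout of self.data built in _load_anno below, assuming window_size=3
    # and an input line "a b":
    #   'NGRAM': tokens [<s>, a, b, <e>] yield the samples
    #            (<s>, a, b) and (a, b, <e>)
    #   'SEQ':   a single pair (src_seq=[<s>, a, b], trg_seq=[a, b, <e>])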

    def _load_anno(self):
        self.data = []
        with tarfile.open(self.data_file) as tf:
            filename = './simple-examples/data/ptb.{}.txt'.format(self.mode)
            f = tf.extractfile(filename)

            UNK = self.word_idx['<unk>']
            for l in f:
                if self.data_type == 'NGRAM':
                    assert self.window_size > -1, 'Invalid gram length'
                    l = ['<s>'] + l.strip().split() + ['<e>']
                    if len(l) >= self.window_size:
                        l = [self.word_idx.get(w, UNK) for w in l]
                        for i in six.moves.range(self.window_size, len(l) + 1):
                            self.data.append(tuple(l[i - self.window_size:i]))
                elif self.data_type == 'SEQ':
                    l = l.strip().split()
                    l = [self.word_idx.get(w, UNK) for w in l]
                    src_seq = [self.word_idx['<s>']] + l
                    trg_seq = l + [self.word_idx['<e>']]
                    if self.window_size > 0 and len(src_seq) > self.window_size:
                        continue
                    self.data.append((src_seq, trg_seq))
                else:
                    assert False, 'Unknown data type'

    def __getitem__(self, idx):
        return tuple([np.array(d) for d in self.data[idx]])

    def __len__(self):
        return len(self.data)
@@ -0,0 +1,173 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import six
import numpy as np
import collections
import nltk
from nltk.corpus import movie_reviews
import zipfile
from functools import cmp_to_key
from itertools import chain

import paddle
from paddle.io import Dataset

__all__ = ['MovieReviews']

URL = "https://corpora.bj.bcebos.com/movie_reviews%2Fmovie_reviews.zip"
MD5 = '155de2b77c6834dd8eea7cbe88e93acb'

NUM_TRAINING_INSTANCES = 1600
NUM_TOTAL_INSTANCES = 2000


class MovieReviews(Dataset):
    """
    Implementation of `NLTK movie reviews <http://www.nltk.org/nltk_data/>`_ dataset.
    The dataset is downloaded automatically on first use.

    Args:
        mode(str): 'train' or 'test' mode. Default 'train'.

    Returns:
        Dataset: instance of movie reviews dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import MovieReviews

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, word, category):
                    return paddle.sum(word), category

            paddle.disable_static()

            movie_reviews = MovieReviews(mode='train')

            for i in range(10):
                word_list, category = movie_reviews[i]
                word_list = paddle.to_tensor(word_list)
                category = paddle.to_tensor(category)

                model = SimpleNet()
                word_list, category = model(word_list, category)
                print(word_list.numpy().shape, category.numpy())

    """

    def __init__(self, mode='train'):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self._download_data_if_not_yet()

        # read dataset into memory
        self._load_sentiment_data()

    def _get_word_dict(self):
        """
        Sort words by the frequency with which they occur in the samples.
        :return:
            words_freq_sorted
        """
        words_freq_sorted = list()
        word_freq_dict = collections.defaultdict(int)

        for category in movie_reviews.categories():
            for field in movie_reviews.fileids(category):
                for words in movie_reviews.words(field):
                    word_freq_dict[words] += 1
        words_sort_list = list(six.iteritems(word_freq_dict))
        words_sort_list.sort(key=cmp_to_key(lambda a, b: b[1] - a[1]))
        for index, word in enumerate(words_sort_list):
            words_freq_sorted.append((word[0], index))
        return words_freq_sorted

    def _sort_files(self):
        """
        Sort the sample files so that negative and positive reviews are
        interleaved for cross reading.
        :return:
            files_list
        """
        files_list = list()
        neg_file_list = movie_reviews.fileids('neg')
        pos_file_list = movie_reviews.fileids('pos')
        files_list = list(
            chain.from_iterable(list(zip(neg_file_list, pos_file_list))))
        return files_list

    def _load_sentiment_data(self):
        """
        Load the dataset.
        :return:
            data_set
        """
        self.data = []
        words_ids = dict(self._get_word_dict())
        for sample_file in self._sort_files():
            words_list = list()
            category = 0 if 'neg' in sample_file else 1
            for word in movie_reviews.words(sample_file):
                words_list.append(words_ids[word.lower()])
            self.data.append((words_list, category))

    def _download_data_if_not_yet(self):
        """
        Download the dataset if it has not been downloaded yet.
        """
        try:
            # download and extract movie_reviews.zip
            paddle.dataset.common.download(
                URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip')
            path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora')
            filename = os.path.join(path, 'movie_reviews.zip')
            zip_file = zipfile.ZipFile(filename)
            zip_file.extractall(path)
            zip_file.close()
            # make sure that nltk can find the data
            if paddle.dataset.common.DATA_HOME not in nltk.data.path:
                nltk.data.path.append(paddle.dataset.common.DATA_HOME)
            movie_reviews.categories()
        except LookupError:
            print("Downloading movie_reviews data set, please wait.....")
            nltk.download(
                'movie_reviews', download_dir=paddle.dataset.common.DATA_HOME)
            print("Download data set success.....")
            print("Path is " + nltk.data.find('corpora/movie_reviews').path)

    def __getitem__(self, idx):
        if self.mode == 'test':
            idx += NUM_TRAINING_INSTANCES
        data = self.data[idx]
        return np.array(data[0]), np.array(data[1])

    def __len__(self):
        if self.mode == 'train':
            return NUM_TRAINING_INSTANCES
        else:
            return NUM_TOTAL_INSTANCES - NUM_TRAINING_INSTANCES
@@ -0,0 +1,219 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import zipfile
import re
import random
import functools
import six

import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['Movielens']

age_table = [1, 18, 25, 35, 45, 50, 56]

URL = 'https://dataset.bj.bcebos.com/movielens%2Fml-1m.zip'
MD5 = 'c4d9eecfca2ab87c1945afe126590906'


class MovieInfo(object):
    """
    Movie id, title and categories information are stored in MovieInfo.
    """

    def __init__(self, index, categories, title):
        self.index = int(index)
        self.categories = categories
        self.title = title

    def value(self, categories_dict, movie_title_dict):
        """
        Get information from a movie.
        """
        return [[self.index], [categories_dict[c] for c in self.categories],
                [movie_title_dict[w.lower()] for w in self.title.split()]]

    def __str__(self):
        return "<MovieInfo id(%d), title(%s), categories(%s)>" % (
            self.index, self.title, self.categories)

    def __repr__(self):
        return self.__str__()


class UserInfo(object):
    """
    User id, gender, age, and job information are stored in UserInfo.
    """

    def __init__(self, index, gender, age, job_id):
        self.index = int(index)
        self.is_male = gender == 'M'
        self.age = age_table.index(int(age))
        self.job_id = int(job_id)

    def value(self):
        """
        Get information from a user.
        """
        return [[self.index], [0 if self.is_male else 1], [self.age],
                [self.job_id]]

    def __str__(self):
        return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
            self.index, "M"
            if self.is_male else "F", age_table[self.age], self.job_id)

    def __repr__(self):
        return str(self)
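

# Each sample assembled in Movielens._load_data below is
# usr.value() + mov.value(...) + [[rating]], i.e. eight arrays:
# [user_id], [gender], [age], [job_id], [movie_id], [category ids...],
# [title word ids...], [rating]; the docstring example reads the last
# three via movielens[i][-3:].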


class Movielens(Dataset):
    """
    Implementation of `Movielens 1-M <https://grouplens.org/datasets/movielens/1m/>`_ dataset.

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        test_ratio(float): split ratio for test sample. Default 0.1.
        rand_seed(int): random seed. Default 0.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of Movielens 1-M dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import Movielens

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, category, title, rating):
                    return paddle.sum(category), paddle.sum(title), paddle.sum(rating)

            paddle.disable_static()

            movielens = Movielens(mode='train')

            for i in range(10):
                category, title, rating = movielens[i][-3:]
                category = paddle.to_tensor(category)
                title = paddle.to_tensor(title)
                rating = paddle.to_tensor(rating)

                model = SimpleNet()
                category, title, rating = model(category, title, rating)
                print(category.numpy().shape, title.numpy().shape, rating.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 test_ratio=0.1,
                 rand_seed=0,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'sentiment', download)

        self.test_ratio = test_ratio
        self.rand_seed = rand_seed

        np.random.seed(rand_seed)
        self._load_meta_info()
        self._load_data()

    def _load_meta_info(self):
        pattern = re.compile(r'^(.*)\((\d+)\)$')
        self.movie_info = dict()
        self.movie_title_dict = dict()
        self.categories_dict = dict()
        self.user_info = dict()
        with zipfile.ZipFile(self.data_file) as package:
            for info in package.infolist():
                assert isinstance(info, zipfile.ZipInfo)
            title_word_set = set()
            categories_set = set()
            with package.open('ml-1m/movies.dat') as movie_file:
                for i, line in enumerate(movie_file):
                    line = cpt.to_text(line, encoding='latin')
                    movie_id, title, categories = line.strip().split('::')
                    categories = categories.split('|')
                    for c in categories:
                        categories_set.add(c)
                    title = pattern.match(title).group(1)
                    self.movie_info[int(movie_id)] = MovieInfo(
                        index=movie_id, categories=categories, title=title)
                    for w in title.split():
                        title_word_set.add(w.lower())

            for i, w in enumerate(title_word_set):
                self.movie_title_dict[w] = i

            for i, c in enumerate(categories_set):
                self.categories_dict[c] = i

            with package.open('ml-1m/users.dat') as user_file:
                for line in user_file:
                    line = cpt.to_text(line, encoding='latin')
                    uid, gender, age, job, _ = line.strip().split("::")
                    self.user_info[int(uid)] = UserInfo(
                        index=uid, gender=gender, age=age, job_id=job)

    def _load_data(self):
        self.data = []
        is_test = self.mode == 'test'
        with zipfile.ZipFile(self.data_file) as package:
            with package.open('ml-1m/ratings.dat') as rating:
                for line in rating:
                    line = cpt.to_text(line, encoding='latin')
                    # each rating goes to the test split with probability
                    # test_ratio; keep the line only if its split matches
                    # the requested mode
                    if (np.random.random() < self.test_ratio) == is_test:
                        uid, mov_id, rating, _ = line.strip().split("::")
                        uid = int(uid)
                        mov_id = int(mov_id)
                        # linearly rescale the 1-5 star rating to [-3.0, 5.0]
                        rating = float(rating) * 2 - 5.0

                        mov = self.movie_info[mov_id]
                        usr = self.user_info[uid]
                        self.data.append(usr.value() + \
                            mov.value(self.categories_dict, self.movie_title_dict) + \
                            [[rating]])

    def __getitem__(self, idx):
        data = self.data[idx]
        return tuple([np.array(d) for d in data])

    def __len__(self):
        return len(self.data)
@@ -0,0 +1,110 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import numpy as np

import paddle.dataset.common
from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ["UCIHousing"]

URL = 'http://paddlemodels.bj.bcebos.com/uci_housing/housing.data'
MD5 = 'd4accdce7a25600298819f8e28e8d593'
feature_names = [
    'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX',
    'PTRATIO', 'B', 'LSTAT'
]
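
# feature_names lists the 13 input columns; together with the MEDV target
# column each record has the 14 values assumed by feature_num=14 in
# UCIHousing._load_data below.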


class UCIHousing(Dataset):
    """
    Implementation of `UCI housing <https://archive.ics.uci.edu/ml/datasets/Housing>`_
    dataset.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train' or 'test' mode. Default 'train'.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of UCI housing dataset.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import UCIHousing

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, feature, target):
                    return paddle.sum(feature), target

            paddle.disable_static()

            uci_housing = UCIHousing(mode='train')

            for i in range(10):
                feature, target = uci_housing[i]
                feature = paddle.to_tensor(feature)
                target = paddle.to_tensor(target)

                model = SimpleNet()
                feature, target = model(feature, target)
                print(feature.numpy().shape, target.numpy())

    """

    def __init__(self, data_file=None, mode='train', download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(data_file, URL, MD5,
                                                        'uci_housing', download)

        # read dataset into memory
        self._load_data()

    def _load_data(self, feature_num=14, ratio=0.8):
        data = np.fromfile(self.data_file, sep=' ')
        data = data.reshape(data.shape[0] // feature_num, feature_num)
        maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
            axis=0) / data.shape[0]
        # mean-normalize each feature column by its value range
        for i in six.moves.range(feature_num - 1):
            data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
        offset = int(data.shape[0] * ratio)
        if self.mode == 'train':
            self.data = data[:offset]
        elif self.mode == 'test':
            self.data = data[offset:]

    def __getitem__(self, idx):
        data = self.data[idx]
        return np.array(data[:-1]), np.array(data[-1:])

    def __len__(self):
        return len(self.data)
@@ -0,0 +1,133 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import io
import tarfile
import numpy as np
from PIL import Image

from paddle.io import Dataset
from .utils import _check_exists_and_download

__all__ = ["VOC2012"]

VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\
VOCtrainval_11-May-2012.tar'

VOC_MD5 = '131da710f39b47a43fdfa256cbc11976'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'

CACHE_DIR = 'voc2012'

MODE_FLAG_MAP = {'train': 'trainval', 'test': 'train', 'valid': 'val'}


class VOC2012(Dataset):
    """
    Implementation of `VOC2012 <http://host.robots.ox.ac.uk/pascal/VOC/voc2012/>`_ dataset.

    Args:
        data_file(str): path to data file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train', 'valid' or 'test' mode. Default 'train'.
        transform(callable): transform to perform on image, None for no transform.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import VOC2012

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, image, label):
                    return paddle.sum(image), label

            paddle.disable_static()

            voc2012 = VOC2012(mode='train')

            for i in range(10):
                image, label = voc2012[i]
                image = paddle.cast(paddle.to_tensor(image), 'float32')
                label = paddle.to_tensor(label)

                model = SimpleNet()
                image, label = model(image, label)
                print(image.numpy().shape, label.numpy().shape)

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'valid', 'test'], \
            "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
        self.flag = MODE_FLAG_MAP[mode.lower()]

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, VOC_URL, VOC_MD5, CACHE_DIR, download)
        self.transform = transform

        # read dataset into memory
        self._load_anno()

    def _load_anno(self):
        self.name2mem = {}
        self.data_tar = tarfile.open(self.data_file)
        for ele in self.data_tar.getmembers():
            self.name2mem[ele.name] = ele

        set_file = SET_FILE.format(self.flag)
        sets = self.data_tar.extractfile(self.name2mem[set_file])

        self.data = []
        self.labels = []

        for line in sets:
            line = line.strip()
            data = DATA_FILE.format(line.decode('utf-8'))
            label = LABEL_FILE.format(line.decode('utf-8'))
            self.data.append(data)
            self.labels.append(label)

    def __getitem__(self, idx):
        data_file = self.data[idx]
        label_file = self.labels[idx]

        data = self.data_tar.extractfile(self.name2mem[data_file]).read()
        label = self.data_tar.extractfile(self.name2mem[label_file]).read()
        data = Image.open(io.BytesIO(data))
        label = Image.open(io.BytesIO(label))
        data = np.array(data)
        label = np.array(label)
        if self.transform is not None:
            data = self.transform(data)
        return data, label

    def __len__(self):
        return len(self.data)
@@ -0,0 +1,179 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import tarfile
import numpy as np
import gzip

from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['WMT14']

URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
                'cslm_joint_paper/data/dev+test.tgz')
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for testing. The original data is too large
# and will be added later.
URL_TRAIN = ('http://paddlemodels.bj.bcebos.com/wmt/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'

START = "<s>"
END = "<e>"
UNK = "<unk>"
UNK_IDX = 2
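
# UNK_IDX == 2 assumes the dictionary files inside the archive list the
# special tokens in the order <s>, <e>, <unk>, so that __to_dict in
# _load_data assigns them indices 0, 1 and 2 respectively.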


class WMT14(Dataset):
    """
    Implementation of `WMT14 <http://www.statmt.org/wmt14/>`_ test dataset.
    The original WMT14 dataset is too large, so a small subset is provided
    instead. This module will download the dataset from
    http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train', 'test' or 'gen'. Default 'train'
        dict_size(int): word dictionary size. Default -1.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of WMT14 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import WMT14

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, src_ids, trg_ids, trg_ids_next):
                    return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)

            paddle.disable_static()

            wmt14 = WMT14(mode='train', dict_size=50)

            for i in range(10):
                src_ids, trg_ids, trg_ids_next = wmt14[i]
                src_ids = paddle.to_tensor(src_ids)
                trg_ids = paddle.to_tensor(trg_ids)
                trg_ids_next = paddle.to_tensor(trg_ids_next)

                model = SimpleNet()
                src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
                print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 dict_size=-1,
                 download=True):
        assert mode.lower() in ['train', 'test', 'gen'], \
            "mode should be 'train', 'test' or 'gen', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, URL_TRAIN, MD5_TRAIN, 'wmt14', download)

        # read dataset into memory
        assert dict_size > 0, "dict_size should be set as positive number"
        self.dict_size = dict_size
        self._load_data()

    def _load_data(self):
        def __to_dict(fd, size):
            out_dict = dict()
            for line_count, line in enumerate(fd):
                if line_count < size:
                    out_dict[cpt.to_text(line.strip())] = line_count
                else:
                    break
            return out_dict

        self.src_ids = []
        self.trg_ids = []
        self.trg_ids_next = []
        with tarfile.open(self.data_file, mode='r') as f:
            names = [
                each_item.name for each_item in f
                if each_item.name.endswith("src.dict")
            ]
            assert len(names) == 1
            self.src_dict = __to_dict(f.extractfile(names[0]), self.dict_size)
            names = [
                each_item.name for each_item in f
                if each_item.name.endswith("trg.dict")
            ]
            assert len(names) == 1
            self.trg_dict = __to_dict(f.extractfile(names[0]), self.dict_size)

            file_name = "{}/{}".format(self.mode, self.mode)
            names = [
                each_item.name for each_item in f
                if each_item.name.endswith(file_name)
            ]
            for name in names:
                for line in f.extractfile(name):
                    line = cpt.to_text(line)
                    line_split = line.strip().split('\t')
                    if len(line_split) != 2:
                        continue
                    src_seq = line_split[0]  # one source sequence
                    src_words = src_seq.split()
                    src_ids = [
                        self.src_dict.get(w, UNK_IDX)
                        for w in [START] + src_words + [END]
                    ]

                    trg_seq = line_split[1]  # one target sequence
                    trg_words = trg_seq.split()
                    trg_ids = [self.trg_dict.get(w, UNK_IDX) for w in trg_words]

                    # remove sequence whose length > 80 in training mode
                    if len(src_ids) > 80 or len(trg_ids) > 80:
                        continue
                    trg_ids_next = trg_ids + [self.trg_dict[END]]
                    trg_ids = [self.trg_dict[START]] + trg_ids

                    self.src_ids.append(src_ids)
                    self.trg_ids.append(trg_ids)
                    self.trg_ids_next.append(trg_ids_next)

    def __getitem__(self, idx):
        return (np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]),
                np.array(self.trg_ids_next[idx]))

    def __len__(self):
        return len(self.src_ids)

    def get_dict(self, reverse=False):
        src_dict, trg_dict = self.src_dict, self.trg_dict
        if reverse:
            src_dict = {v: k for k, v in six.iteritems(src_dict)}
            trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
        return src_dict, trg_dict
@@ -0,0 +1,244 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import six
import tarfile
import numpy as np
from collections import defaultdict

import paddle
from paddle.io import Dataset
import paddle.compat as cpt
from .utils import _check_exists_and_download

__all__ = ['WMT16']

DATA_URL = ("http://paddlemodels.bj.bcebos.com/wmt/wmt16.tar.gz")
DATA_MD5 = "0c38be43600334966403524a40dcd81e"

TOTAL_EN_WORDS = 11250
TOTAL_DE_WORDS = 19220

START_MARK = "<s>"
END_MARK = "<e>"
UNK_MARK = "<unk>"


class WMT16(Dataset):
    """
    Implementation of `WMT16 <http://www.statmt.org/wmt16/>`_ test dataset.
    ACL2016 Multimodal Machine Translation. Please see this website for more
    details: http://www.statmt.org/wmt16/multimodal-task.html#task1

    If you use the dataset created for your task, please cite the following paper:
    Multi30K: Multilingual English-German Image Descriptions.

    .. code-block:: text

        @article{elliott-EtAl:2016:VL16,
            author    = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
            title     = {Multi30K: Multilingual English-German Image Descriptions},
            booktitle = {Proceedings of the 6th Workshop on Vision and Language},
            year      = {2016},
            pages     = {70--74}
        }

    Args:
        data_file(str): path to data tar file, can be set None if
            :attr:`download` is True. Default None
        mode(str): 'train', 'test' or 'val'. Default 'train'
        src_dict_size(int): word dictionary size for source language word. Default -1.
        trg_dict_size(int): word dictionary size for target language word. Default -1.
        lang(str): source language, 'en' or 'de'. Default 'en'.
        download(bool): whether to download dataset automatically if
            :attr:`data_file` is not set. Default True

    Returns:
        Dataset: instance of WMT16 dataset

    Examples:

        .. code-block:: python

            import paddle
            from paddle.incubate.hapi.datasets import WMT16

            class SimpleNet(paddle.nn.Layer):
                def __init__(self):
                    super(SimpleNet, self).__init__()

                def forward(self, src_ids, trg_ids, trg_ids_next):
                    return paddle.sum(src_ids), paddle.sum(trg_ids), paddle.sum(trg_ids_next)

            paddle.disable_static()

            wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50)

            for i in range(10):
                src_ids, trg_ids, trg_ids_next = wmt16[i]
                src_ids = paddle.to_tensor(src_ids)
                trg_ids = paddle.to_tensor(trg_ids)
                trg_ids_next = paddle.to_tensor(trg_ids_next)

                model = SimpleNet()
                src_ids, trg_ids, trg_ids_next = model(src_ids, trg_ids, trg_ids_next)
                print(src_ids.numpy(), trg_ids.numpy(), trg_ids_next.numpy())

    """

    def __init__(self,
                 data_file=None,
                 mode='train',
                 src_dict_size=-1,
                 trg_dict_size=-1,
                 lang='en',
                 download=True):
        assert mode.lower() in ['train', 'test', 'val'], \
            "mode should be 'train', 'test' or 'val', but got {}".format(mode)
        self.mode = mode.lower()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, DATA_URL, DATA_MD5, 'wmt16', download)

        self.lang = lang
        assert src_dict_size > 0, "src_dict_size should be set as positive number"
        assert trg_dict_size > 0, "trg_dict_size should be set as positive number"
        self.src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if lang == "en"
                                                 else TOTAL_DE_WORDS))
        self.trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if lang == "en"
                                                 else TOTAL_EN_WORDS))

        # load source and target word dict
        self.src_dict = self._load_dict(lang, src_dict_size)
        self.trg_dict = self._load_dict("de" if lang == "en" else "en",
                                        trg_dict_size)

        # load data
        self._load_data()

    def _load_dict(self, lang, dict_size, reverse=False):
        dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                                 "wmt16/%s_%d.dict" % (lang, dict_size))
        if not os.path.exists(dict_path) or (
                len(open(dict_path, "rb").readlines()) != dict_size):
            self._build_dict(dict_path, dict_size, lang)

        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = cpt.to_text(line.strip())
                else:
                    word_dict[cpt.to_text(line.strip())] = idx
        return word_dict
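
    # The dict file written by _build_dict below always starts with the three
    # special tokens <s>, <e> and <unk>, so indices 0-2 are reserved and
    # corpus words are added until idx + 3 == dict_size.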

    def _build_dict(self, dict_path, dict_size, lang):
        word_dict = defaultdict(int)
        with tarfile.open(self.data_file, mode="r") as f:
            for line in f.extractfile("wmt16/train"):
                line = cpt.to_text(line)
                line_split = line.strip().split("\t")
                if len(line_split) != 2:
                    continue
                sen = line_split[0] if self.lang == "en" else line_split[1]
                for w in sen.split():
                    word_dict[w] += 1

        with open(dict_path, "wb") as fout:
            fout.write(
                cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)))
            for idx, word in enumerate(
                    sorted(
                        six.iteritems(word_dict),
                        key=lambda x: x[1],
                        reverse=True)):
                if idx + 3 == dict_size:
                    break
                fout.write(cpt.to_bytes(word[0]))
                fout.write(cpt.to_bytes('\n'))

    def _load_data(self):
        # the indices for the start mark, end mark, and unk are the same in
        # the source and target language, so we use the source language
        # dictionary to determine them.
        start_id = self.src_dict[START_MARK]
        end_id = self.src_dict[END_MARK]
        unk_id = self.src_dict[UNK_MARK]

        src_col = 0 if self.lang == "en" else 1
        trg_col = 1 - src_col

        self.src_ids = []
        self.trg_ids = []
        self.trg_ids_next = []
        with tarfile.open(self.data_file, mode="r") as f:
            for line in f.extractfile("wmt16/{}".format(self.mode)):
                line = cpt.to_text(line)
                line_split = line.strip().split("\t")
                if len(line_split) != 2:
                    continue
                src_words = line_split[src_col].split()
                src_ids = [start_id] + [
                    self.src_dict.get(w, unk_id) for w in src_words
                ] + [end_id]

                trg_words = line_split[trg_col].split()
                trg_ids = [self.trg_dict.get(w, unk_id) for w in trg_words]

                trg_ids_next = trg_ids + [end_id]
                trg_ids = [start_id] + trg_ids

                self.src_ids.append(src_ids)
                self.trg_ids.append(trg_ids)
                self.trg_ids_next.append(trg_ids_next)

    def __getitem__(self, idx):
        return (np.array(self.src_ids[idx]), np.array(self.trg_ids[idx]),
                np.array(self.trg_ids_next[idx]))

    def __len__(self):
        return len(self.src_ids)

    def get_dict(self, lang, reverse=False):
        """
        return the word dictionary for the specified language.

        Args:
            lang(string): A string indicating which language is the source
                language. Available options are: "en" for English
                and "de" for German.
            reverse(bool): If reverse is set to False, the returned python
                dictionary will use word as key and index as value. If
                reverse is set to True, the returned python dictionary will
                use index as key and word as value.

        Returns:
            dict: The word dictionary for the specific language.
        """
        dict_size = self.src_dict_size if lang == self.lang else self.trg_dict_size

        dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
                                 "wmt16/%s_%d.dict" % (lang, dict_size))
        assert os.path.exists(dict_path), (
            "Word dictionary does not exist. "
            "Please invoke paddle.dataset.wmt16.train/test/validation first "
            "to build the dictionary.")
        return self._load_dict(lang, dict_size, reverse)
@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestCifar10Train(unittest.TestCase):
    def test_main(self):
        cifar = Cifar10(mode='train')
        self.assertTrue(len(cifar) == 50000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 50000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 9)


class TestCifar10Test(unittest.TestCase):
    def test_main(self):
        cifar = Cifar10(mode='test')
        self.assertTrue(len(cifar) == 10000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 10000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 9)


class TestCifar100Train(unittest.TestCase):
    def test_main(self):
        cifar = Cifar100(mode='train')
        self.assertTrue(len(cifar) == 50000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 50000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 99)


class TestCifar100Test(unittest.TestCase):
    def test_main(self):
        cifar = Cifar100(mode='test')
        self.assertTrue(len(cifar) == 10000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 10000)
        data, label = cifar[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(data.shape[0] == 3072)
        self.assertTrue(0 <= int(label) <= 99)


if __name__ == '__main__':
    unittest.main()
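The CIFAR tests above assert a flat 3072-element sample per image, i.e. 3 x 32 x 32. A hedged sketch of recovering the image layout (channel-first ordering is an assumption based on the CIFAR batch format, not something the tests check):

from paddle.incubate.hapi.datasets import Cifar10

cifar = Cifar10(mode='train')
data, label = cifar[0]
image = data.reshape(3, 32, 32)  # assumed CHW: 3 channels of 32x32 pixels
print(image.shape, int(label))   # (3, 32, 32) and a class id in [0, 9]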
@ -0,0 +1,41 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestConll05st(unittest.TestCase):
    def test_main(self):
        conll05st = Conll05st()
        self.assertTrue(len(conll05st) == 5267)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 5267)
        sample = conll05st[idx]
        self.assertTrue(len(sample) == 9)
        for s in sample:
            self.assertTrue(len(s.shape) == 1)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestImdbTrain(unittest.TestCase):
    def test_main(self):
        imdb = Imdb(mode='train')
        self.assertTrue(len(imdb) == 25000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 25000)
        data, label = imdb[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(label.shape[0] == 1)
        self.assertTrue(int(label) in [0, 1])


class TestImdbTest(unittest.TestCase):
    def test_main(self):
        imdb = Imdb(mode='test')
        self.assertTrue(len(imdb) == 25000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 25000)
        data, label = imdb[idx]
        self.assertTrue(len(data.shape) == 1)
        self.assertTrue(label.shape[0] == 1)
        self.assertTrue(int(label) in [0, 1])


if __name__ == '__main__':
    unittest.main()
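The Imdb tests above only check shapes and the 0/1 label. A short inspection sketch; the `word_idx` vocabulary attribute is an assumption about the Imdb implementation, not something these tests exercise:

from paddle.incubate.hapi.datasets import Imdb

imdb = Imdb(mode='train')
data, label = imdb[0]
print(data.shape, int(label))  # (review_length,) word ids and 0/1 sentiment
print(len(imdb.word_idx))      # assumed word -> id vocabulary attribute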
@ -0,0 +1,51 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestImikolovTrain(unittest.TestCase):
    def test_main(self):
        imikolov = Imikolov(mode='train', data_type='NGRAM', window_size=2)
        self.assertTrue(len(imikolov) == 929589)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 929589)
        data = imikolov[idx]
        self.assertTrue(len(data) == 2)


class TestImikolovTest(unittest.TestCase):
    def test_main(self):
        imikolov = Imikolov(mode='test', data_type='NGRAM', window_size=2)
        self.assertTrue(len(imikolov) == 82430)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 82430)
        data = imikolov[idx]
        self.assertTrue(len(data) == 2)


if __name__ == '__main__':
    unittest.main()
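The `len(data) == 2` assertions follow from `window_size=2`: with `data_type='NGRAM'`, each sample is one 2-word window slid over the corpus. A sketch of reading such a pair (the context/target interpretation is assumed from the n-gram setting):

from paddle.incubate.hapi.datasets import Imikolov

imikolov = Imikolov(mode='train', data_type='NGRAM', window_size=2)
context, target = imikolov[0]  # one 2-gram: a word id and its successor
print(context, target)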
@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestMovieReviewsTrain(unittest.TestCase):
    def test_main(self):
        movie_reviews = MovieReviews(mode='train')
        self.assertTrue(len(movie_reviews) == 1600)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 1600)
        data = movie_reviews[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(int(data[1]) in [0, 1])


class TestMovieReviewsTest(unittest.TestCase):
    def test_main(self):
        movie_reviews = MovieReviews(mode='test')
        self.assertTrue(len(movie_reviews) == 400)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 400)
        data = movie_reviews[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(int(data[1]) in [0, 1])


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,61 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestMovielensTrain(unittest.TestCase):
    def test_main(self):
        movielens = Movielens(mode='train')
        # the movielens dataset is randomly split into train/test,
        # so the dataset length is not checked here

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 900000)
        data = movielens[idx]
        self.assertTrue(len(data) == 8)
        for i, d in enumerate(data):
            self.assertTrue(len(d.shape) == 1)
            if i not in [5, 6]:
                self.assertTrue(d.shape[0] == 1)


class TestMovielensTest(unittest.TestCase):
    def test_main(self):
        movielens = Movielens(mode='test')
        # the movielens dataset is randomly split into train/test,
        # so the dataset length is not checked here

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 100000)
        data = movielens[idx]
        self.assertTrue(len(data) == 8)
        for i, d in enumerate(data):
            self.assertTrue(len(d.shape) == 1)
            if i not in [5, 6]:
                self.assertTrue(d.shape[0] == 1)


if __name__ == '__main__':
    unittest.main()
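The Movielens loop above implies 8 fields per sample, all length-1 except indices 5 and 6, which are variable-length (presumably the tokenized movie title and category list; the exact field meanings are an assumption). A sketch that makes those shapes visible:

from paddle.incubate.hapi.datasets import Movielens

movielens = Movielens(mode='train')
sample = movielens[0]
for i, field in enumerate(sample):
    print(i, field.shape)  # (1,) everywhere except the two variable fields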
@ -0,0 +1,104 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestUCIHousingTrain(unittest.TestCase):
    def test_main(self):
        uci_housing = UCIHousing(mode='train')
        self.assertTrue(len(uci_housing) == 404)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 404)
        data = uci_housing[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(data[0].shape[0] == 13)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(data[1].shape[0] == 1)


class TestUCIHousingTest(unittest.TestCase):
    def test_main(self):
        uci_housing = UCIHousing(mode='test')
        self.assertTrue(len(uci_housing) == 102)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 102)
        data = uci_housing[idx]
        self.assertTrue(len(data) == 2)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(data[0].shape[0] == 13)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(data[1].shape[0] == 1)


class TestWMT14Train(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='train', dict_size=50)
        self.assertTrue(len(wmt14) == 191155)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 191155)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Test(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='test', dict_size=50)
        self.assertTrue(len(wmt14) == 5957)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 5957)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Gen(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='gen', dict_size=50)
        self.assertTrue(len(wmt14) == 3001)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 3001)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


if __name__ == '__main__':
    unittest.main()
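The UCI Housing assertions pin each sample to 13 features and a single target, which maps directly onto a 13 -> 1 linear regressor. A minimal sketch under that assumption (dynamic mode, as in this commit's docstring examples):

import paddle
from paddle.incubate.hapi.datasets import UCIHousing

paddle.disable_static()
uci = UCIHousing(mode='train')
feat, price = uci[0]
linear = paddle.nn.Linear(13, 1)  # 13 features in, 1 predicted price out
pred = linear(paddle.to_tensor(feat.reshape(1, 13).astype('float32')))
print(pred.numpy().shape, price)  # (1, 1) prediction vs. ground truth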
@ -0,0 +1,70 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import voc2012, VOC2012
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download

# VOC2012 is too large for a unit test to download; stub a small dataset here
voc2012.VOC_URL = 'https://paddlemodels.bj.bcebos.com/voc2012_stub/VOCtrainval_11-May-2012.tar'
voc2012.VOC_MD5 = '34cb1fe5bdc139a5454b25b16118fff8'


class TestVOC2012Train(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='train')
        self.assertTrue(len(voc2012) == 3)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 3)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


class TestVOC2012Valid(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='valid')
        self.assertTrue(len(voc2012) == 1)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 1)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


class TestVOC2012Test(unittest.TestCase):
    def test_main(self):
        voc2012 = VOC2012(mode='test')
        self.assertTrue(len(voc2012) == 2)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        # (the stub test split has 2 samples, so draw from [0, 2))
        idx = np.random.randint(0, 2)
        image, label = voc2012[idx]
        self.assertTrue(len(image.shape) == 3)
        self.assertTrue(len(label.shape) == 2)


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,119 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import os
import numpy as np
import tempfile
import shutil
import cv2

from paddle.incubate.hapi.datasets import *
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download


class TestWMT14Train(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='train', dict_size=50)
        self.assertTrue(len(wmt14) == 191155)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 191155)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Test(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='test', dict_size=50)
        self.assertTrue(len(wmt14) == 5957)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 5957)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT14Gen(unittest.TestCase):
    def test_main(self):
        wmt14 = WMT14(mode='gen', dict_size=50)
        self.assertTrue(len(wmt14) == 3001)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 3001)
        data = wmt14[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Train(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(
            mode='train', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 29000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 29000)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Test(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(
            mode='test', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 1000)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 1000)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


class TestWMT16Val(unittest.TestCase):
    def test_main(self):
        wmt16 = WMT16(mode='val', src_dict_size=50, trg_dict_size=50, lang='en')
        self.assertTrue(len(wmt16) == 1014)

        # traversing the whole dataset may take a
        # long time, so randomly check 1 sample
        idx = np.random.randint(0, 1014)
        data = wmt16[idx]
        self.assertTrue(len(data) == 3)
        self.assertTrue(len(data[0].shape) == 1)
        self.assertTrue(len(data[1].shape) == 1)
        self.assertTrue(len(data[2].shape) == 1)


if __name__ == '__main__':
    unittest.main()
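Each WMT sample asserted above is a triple matching the loader logic at the top of this diff: src_ids wrapped in start/end marks, trg_ids prefixed with the start mark, and trg_ids_next suffixed with the end mark, i.e. the same target sentence offset by one position for teacher forcing. A sketch verifying that relationship on one sample:

from paddle.incubate.hapi.datasets import WMT16

wmt16 = WMT16(mode='train', src_dict_size=50, trg_dict_size=50, lang='en')
src_ids, trg_ids, trg_ids_next = wmt16[0]
# trg_ids = [start] + sentence, trg_ids_next = sentence + [end]:
assert (trg_ids[1:] == trg_ids_next[:-1]).all()
print(src_ids.shape, trg_ids.shape, trg_ids_next.shape)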