Make paddle.fluid no longer depend on paddle.v2

This way we can build and test with the WITH_FLUID_ONLY flag set to ON.

- move paddle.v2.dataset and paddle.v2.reader to paddle.dataset and paddle.reader
- remove unused code (which depended on v2) from paddle.dataset and paddle.reader (see the import sketch below)
helinwang-patch-1
Helin Wang 7 years ago
parent f5aa42379f
commit bcf7c36b0b
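A minimal sketch of how user code changes after this commit. The module paths and the mnist reader come from the diff below; everything else (the sample unpacking, the print) is illustrative and assumes the reader keeps its v2 behaviour of returning a reader creator that yields (image, label) samples:

# Old, v2-era import paths that fluid previously depended on:
#   import paddle.v2.dataset.mnist
#   train_creator = paddle.v2.dataset.mnist.train()

# New top-level paths introduced by this move:
import paddle.dataset.mnist

# train() returns a reader creator; calling it yields (image, label) samples.
train_creator = paddle.dataset.mnist.train()
first_image, first_label = next(iter(train_creator()))
print(first_image.shape, first_label)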

@@ -73,12 +73,13 @@ add_custom_target(paddle_python ALL DEPENDS ${paddle_python_deps})
 set(PADDLE_PYTHON_PACKAGE_DIR ${CMAKE_CURRENT_BINARY_DIR}/dist/)

 if (WITH_TESTING)
+  add_subdirectory(paddle/reader/tests)
+  add_subdirectory(paddle/dataset/tests)
   if(NOT WITH_FLUID_ONLY)
     add_subdirectory(paddle/trainer_config_helpers/tests)
     if (WITH_SWIG_PY)
       # enable v2 API unittest only when paddle swig api is compiled
       add_subdirectory(paddle/v2/tests)
-      add_subdirectory(paddle/v2/reader/tests)
       add_subdirectory(paddle/v2/plot/tests)
     endif()
   endif()

@@ -14,8 +14,14 @@
 try:
     from version import full_version as __version__
     from version import commit as __git_commit__
 except ImportError:
     import sys
     sys.stderr.write('''Warning with import paddle: you should not
 import paddle from the source directory; please install paddlepaddle*.whl firstly.'''
                      )
+
+import reader
+import dataset
+import batch
+batch = batch.batch

@@ -28,6 +28,7 @@ import wmt16
 import mq2007
 import flowers
 import voc2012
+import image

 __all__ = [
     'mnist',
@@ -43,4 +44,5 @@ __all__ = [
     'mq2007',
     'flowers',
     'voc2012',
+    'image',
 ]

@@ -31,7 +31,7 @@ images per class.
 import cPickle
 import itertools
 import numpy
-import paddle.v2.dataset.common
+import paddle.dataset.common
 import tarfile

 __all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
@@ -75,7 +75,7 @@ def train100():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
+        paddle.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
         'train')
@@ -90,7 +90,7 @@ def test100():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
+        paddle.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5),
         'test')
@@ -105,7 +105,7 @@ def train10():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
+        paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
         'data_batch')
@@ -120,20 +120,20 @@ def test10():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
+        paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
         'test_batch')


 def fetch():
-    paddle.v2.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
-    paddle.v2.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)
+    paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5)
+    paddle.dataset.common.download(CIFAR100_URL, 'cifar', CIFAR100_MD5)


 def convert(path):
     """
     Converts dataset to recordio format
     """
-    paddle.v2.dataset.common.convert(path, train100(), 1000, "cifar_train100")
-    paddle.v2.dataset.common.convert(path, test100(), 1000, "cifar_test100")
-    paddle.v2.dataset.common.convert(path, train10(), 1000, "cifar_train10")
-    paddle.v2.dataset.common.convert(path, test10(), 1000, "cifar_test10")
+    paddle.dataset.common.convert(path, train100(), 1000, "cifar_train100")
+    paddle.dataset.common.convert(path, test100(), 1000, "cifar_test100")
+    paddle.dataset.common.convert(path, train10(), 1000, "cifar_train10")
+    paddle.dataset.common.convert(path, test10(), 1000, "cifar_test10")

@@ -19,7 +19,7 @@ import errno
 import shutil
 import sys
 import importlib
-import paddle.v2.dataset
+import paddle.dataset
 import cPickle
 import glob
 import cPickle as pickle
@@ -105,24 +105,24 @@ def download(url, module_name, md5sum, save_name=None):

 def fetch_all():
     for module_name in filter(lambda x: not x.startswith("__"),
-                              dir(paddle.v2.dataset)):
+                              dir(paddle.dataset)):
         if "fetch" in dir(
-                importlib.import_module("paddle.v2.dataset.%s" % module_name)):
+                importlib.import_module("paddle.dataset.%s" % module_name)):
             getattr(
-                importlib.import_module("paddle.v2.dataset.%s" % module_name),
+                importlib.import_module("paddle.dataset.%s" % module_name),
                 "fetch")()


 def fetch_all_recordio(path):
     for module_name in filter(lambda x: not x.startswith("__"),
-                              dir(paddle.v2.dataset)):
+                              dir(paddle.dataset)):
         if "convert" in dir(
-                importlib.import_module("paddle.v2.dataset.%s" % module_name)) and \
+                importlib.import_module("paddle.dataset.%s" % module_name)) and \
                 not module_name == "common":
             ds_path = os.path.join(path, module_name)
             must_mkdirs(ds_path)
             getattr(
-                importlib.import_module("paddle.v2.dataset.%s" % module_name),
+                importlib.import_module("paddle.dataset.%s" % module_name),
                 "convert")(ds_path)
@@ -130,7 +130,7 @@ def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
     """
     you can call the function as:

-    split(paddle.v2.dataset.cifar.train10(), line_count=1000,
+    split(paddle.dataset.cifar.train10(), line_count=1000,
         suffix="imikolov-train-%05d.pickle")

    the output files as:

@@ -23,7 +23,7 @@ to initialize SRL model.
 import tarfile
 import gzip
 import itertools
-import paddle.v2.dataset.common
+import paddle.dataset.common

 __all__ = ['test, get_dict', 'get_embedding', 'convert']
@@ -203,14 +203,11 @@ def get_dict():
     Get the word, verb and label dictionary of Wikipedia corpus.
     """
     word_dict = load_dict(
-        paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st',
-                                          WORDDICT_MD5))
+        paddle.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5))
     verb_dict = load_dict(
-        paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st',
-                                          VERBDICT_MD5))
+        paddle.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5))
     label_dict = load_label_dict(
-        paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st',
-                                          TRGDICT_MD5))
+        paddle.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5))
     return word_dict, verb_dict, label_dict
@@ -218,7 +215,7 @@ def get_embedding():
     """
     Get the trained word vector based on Wikipedia corpus.
     """
-    return paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
+    return paddle.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)


 def test():
@@ -235,23 +232,23 @@ def test():
     """
     word_dict, verb_dict, label_dict = get_dict()
     reader = corpus_reader(
-        paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5),
+        paddle.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5),
         words_name='conll05st-release/test.wsj/words/test.wsj.words.gz',
         props_name='conll05st-release/test.wsj/props/test.wsj.props.gz')
     return reader_creator(reader, word_dict, verb_dict, label_dict)


 def fetch():
-    paddle.v2.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
-    paddle.v2.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
-    paddle.v2.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
-    paddle.v2.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
-    paddle.v2.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5)
+    paddle.dataset.common.download(WORDDICT_URL, 'conll05st', WORDDICT_MD5)
+    paddle.dataset.common.download(VERBDICT_URL, 'conll05st', VERBDICT_MD5)
+    paddle.dataset.common.download(TRGDICT_URL, 'conll05st', TRGDICT_MD5)
+    paddle.dataset.common.download(EMB_URL, 'conll05st', EMB_MD5)
+    paddle.dataset.common.download(DATA_URL, 'conll05st', DATA_MD5)


 def convert(path):
     """
     Converts dataset to recordio format
     """
-    paddle.v2.dataset.common.convert(path, test(), 1000, "conl105_train")
-    paddle.v2.dataset.common.convert(path, test(), 1000, "conl105_test")
+    paddle.dataset.common.convert(path, test(), 1000, "conl105_train")
+    paddle.dataset.common.convert(path, test(), 1000, "conl105_test")

@@ -34,8 +34,8 @@ import functools
 from common import download
 import tarfile
 import scipy.io as scio
-from paddle.v2.image import *
-from paddle.v2.reader import *
+from paddle.dataset.image import *
+from paddle.reader import *
 import os
 import numpy as np
 from multiprocessing import cpu_count

@@ -20,7 +20,7 @@ of 25,000 highly polar movie reviews for training, and 25,000 for testing.
 Besides, this module also provides API for building dictionary.
 """

-import paddle.v2.dataset.common
+import paddle.dataset.common
 import collections
 import tarfile
 import re
@@ -37,8 +37,7 @@ def tokenize(pattern):
     Read files that match the given pattern. Tokenize and yield each file.
     """
-    with tarfile.open(paddle.v2.dataset.common.download(URL, 'imdb',
-                                                        MD5)) as tarf:
+    with tarfile.open(paddle.dataset.common.download(URL, 'imdb', MD5)) as tarf:
         # Note that we should use tarfile.next(), which does
         # sequential access of member files, other than
         # tarfile.extractfile, which does random access and might
@@ -136,7 +135,7 @@ def word_dict():

 def fetch():
-    paddle.v2.dataset.common.download(URL, 'imdb', MD5)
+    paddle.dataset.common.download(URL, 'imdb', MD5)


 def convert(path):
@@ -144,5 +143,5 @@ def convert(path):
     Converts dataset to recordio format
     """
     w = word_dict()
-    paddle.v2.dataset.common.convert(path, lambda: train(w), 1000, "imdb_train")
-    paddle.v2.dataset.common.convert(path, lambda: test(w), 1000, "imdb_test")
+    paddle.dataset.common.convert(path, lambda: train(w), 1000, "imdb_train")
+    paddle.dataset.common.convert(path, lambda: test(w), 1000, "imdb_test")

@@ -18,7 +18,7 @@ This module will download dataset from
 http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
 into paddle reader creators.
 """
-import paddle.v2.dataset.common
+import paddle.dataset.common
 import collections
 import tarfile
@@ -54,9 +54,9 @@ def build_dict(min_word_freq=50):
     train_filename = './simple-examples/data/ptb.train.txt'
     test_filename = './simple-examples/data/ptb.valid.txt'
     with tarfile.open(
-            paddle.v2.dataset.common.download(
-                paddle.v2.dataset.imikolov.URL, 'imikolov',
-                paddle.v2.dataset.imikolov.MD5)) as tf:
+            paddle.dataset.common.download(paddle.dataset.imikolov.URL,
+                                           'imikolov',
+                                           paddle.dataset.imikolov.MD5)) as tf:
         trainf = tf.extractfile(train_filename)
         testf = tf.extractfile(test_filename)
         word_freq = word_count(testf, word_count(trainf))
@@ -77,9 +77,9 @@ def build_dict(min_word_freq=50):
 def reader_creator(filename, word_idx, n, data_type):
     def reader():
         with tarfile.open(
-                paddle.v2.dataset.common.download(
-                    paddle.v2.dataset.imikolov.URL, 'imikolov',
-                    paddle.v2.dataset.imikolov.MD5)) as tf:
+                paddle.dataset.common.download(
+                    paddle.dataset.imikolov.URL, 'imikolov',
+                    paddle.dataset.imikolov.MD5)) as tf:
             f = tf.extractfile(filename)

             UNK = word_idx['<unk>']
@@ -145,7 +145,7 @@ def test(word_idx, n, data_type=DataType.NGRAM):

 def fetch():
-    paddle.v2.dataset.common.download(URL, "imikolov", MD5)
+    paddle.dataset.common.download(URL, "imikolov", MD5)


 def convert(path):
@@ -154,8 +154,7 @@ def convert(path):
     """
     N = 5
     word_dict = build_dict()
-    paddle.v2.dataset.common.convert(path,
-                                     train(word_dict, N), 1000,
-                                     "imikolov_train")
-    paddle.v2.dataset.common.convert(path,
-                                     test(word_dict, N), 1000, "imikolov_test")
+    paddle.dataset.common.convert(path,
+                                  train(word_dict, N), 1000, "imikolov_train")
+    paddle.dataset.common.convert(path,
+                                  test(word_dict, N), 1000, "imikolov_test")

@@ -17,7 +17,7 @@ MNIST dataset.
 This module will download dataset from http://yann.lecun.com/exdb/mnist/ and
 parse training set and test set into paddle reader creators.
 """
-import paddle.v2.dataset.common
+import paddle.dataset.common
 import subprocess
 import numpy
 import platform
@@ -85,10 +85,10 @@ def train():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
-                                          TRAIN_IMAGE_MD5),
-        paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
-                                          TRAIN_LABEL_MD5), 100)
+        paddle.dataset.common.download(TRAIN_IMAGE_URL, 'mnist',
+                                       TRAIN_IMAGE_MD5),
+        paddle.dataset.common.download(TRAIN_LABEL_URL, 'mnist',
+                                       TRAIN_LABEL_MD5), 100)


 def test():
@@ -102,22 +102,21 @@ def test():
     :rtype: callable
     """
     return reader_creator(
-        paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist',
-                                          TEST_IMAGE_MD5),
-        paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist',
-                                          TEST_LABEL_MD5), 100)
+        paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5),
+        paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TEST_LABEL_MD5),
+        100)


 def fetch():
-    paddle.v2.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
-    paddle.v2.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
-    paddle.v2.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
-    paddle.v2.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
+    paddle.dataset.common.download(TRAIN_IMAGE_URL, 'mnist', TRAIN_IMAGE_MD5)
+    paddle.dataset.common.download(TRAIN_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)
+    paddle.dataset.common.download(TEST_IMAGE_URL, 'mnist', TEST_IMAGE_MD5)
+    paddle.dataset.common.download(TEST_LABEL_URL, 'mnist', TRAIN_LABEL_MD5)


 def convert(path):
     """
     Converts dataset to recordio format
     """
-    paddle.v2.dataset.common.convert(path, train(), 1000, "minist_train")
-    paddle.v2.dataset.common.convert(path, test(), 1000, "minist_test")
+    paddle.dataset.common.convert(path, train(), 1000, "minist_train")
+    paddle.dataset.common.convert(path, test(), 1000, "minist_test")

@@ -23,7 +23,7 @@ set and test set into paddle reader creators.
 """

 import zipfile
-import paddle.v2.dataset.common
+import paddle.dataset.common
 import re
 import random
 import functools
@@ -100,7 +100,7 @@ USER_INFO = None

 def __initialize_meta_info__():
-    fn = paddle.v2.dataset.common.download(URL, "movielens", MD5)
+    fn = paddle.dataset.common.download(URL, "movielens", MD5)
     global MOVIE_INFO
     if MOVIE_INFO is None:
         pattern = re.compile(r'^(.*)\((\d+)\)$')
@@ -247,15 +247,15 @@ def unittest():

 def fetch():
-    paddle.v2.dataset.common.download(URL, "movielens", MD5)
+    paddle.dataset.common.download(URL, "movielens", MD5)


 def convert(path):
     """
     Converts dataset to recordio format
     """
-    paddle.v2.dataset.common.convert(path, train(), 1000, "movielens_train")
-    paddle.v2.dataset.common.convert(path, test(), 1000, "movielens_test")
+    paddle.dataset.common.convert(path, train(), 1000, "movielens_train")
+    paddle.dataset.common.convert(path, test(), 1000, "movielens_test")


 if __name__ == '__main__':

@@ -26,7 +26,7 @@ from itertools import chain
 import nltk
 from nltk.corpus import movie_reviews

-import paddle.v2.dataset.common
+import paddle.dataset.common

 __all__ = ['train', 'test', 'get_word_dict', 'convert']
 NUM_TRAINING_INSTANCES = 1600
@@ -39,13 +39,13 @@ def download_data_if_not_yet():
     """
     try:
         # make sure that nltk can find the data
-        if paddle.v2.dataset.common.DATA_HOME not in nltk.data.path:
-            nltk.data.path.append(paddle.v2.dataset.common.DATA_HOME)
+        if paddle.dataset.common.DATA_HOME not in nltk.data.path:
+            nltk.data.path.append(paddle.dataset.common.DATA_HOME)
         movie_reviews.categories()
     except LookupError:
         print "Downloading movie_reviews data set, please wait....."
         nltk.download(
-            'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
+            'movie_reviews', download_dir=paddle.dataset.common.DATA_HOME)
         print "Download data set success....."
         print "Path is " + nltk.data.find('corpora/movie_reviews').path
@@ -129,13 +129,12 @@ def test():

 def fetch():
-    nltk.download(
-        'movie_reviews', download_dir=paddle.v2.dataset.common.DATA_HOME)
+    nltk.download('movie_reviews', download_dir=paddle.dataset.common.DATA_HOME)


 def convert(path):
     """
     Converts dataset to recordio format
     """
-    paddle.v2.dataset.common.convert(path, train, 1000, "sentiment_train")
-    paddle.v2.dataset.common.convert(path, test, 1000, "sentiment_test")
+    paddle.dataset.common.convert(path, train, 1000, "sentiment_train")
+    paddle.dataset.common.convert(path, test, 1000, "sentiment_test")

@@ -0,0 +1 @@
+py_test(test_image SRCS test_image.py)

[binary image file changed: 56 KiB before, 56 KiB after]

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.cifar
+import paddle.dataset.cifar
 import unittest
@@ -29,25 +29,25 @@ class TestCIFAR(unittest.TestCase):

     def test_test10(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.cifar.test10())
+            paddle.dataset.cifar.test10())
         self.assertEqual(instances, 10000)
         self.assertEqual(max_label_value, 9)

     def test_train10(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.cifar.train10())
+            paddle.dataset.cifar.train10())
         self.assertEqual(instances, 50000)
         self.assertEqual(max_label_value, 9)

     def test_test100(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.cifar.test100())
+            paddle.dataset.cifar.test100())
         self.assertEqual(instances, 10000)
         self.assertEqual(max_label_value, 99)

     def test_train100(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.cifar.train100())
+            paddle.dataset.cifar.train100())
         self.assertEqual(instances, 50000)
         self.assertEqual(max_label_value, 99)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.common
+import paddle.dataset.common
 import unittest
 import tempfile
 import glob
@@ -24,14 +24,14 @@ class TestCommon(unittest.TestCase):
         with open(temp_path, 'w') as f:
             f.write("Hello\n")
         self.assertEqual('09f7e02f1290be211da707a266f153b3',
-                         paddle.v2.dataset.common.md5file(temp_path))
+                         paddle.dataset.common.md5file(temp_path))

     def test_download(self):
         yi_avatar = 'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
         self.assertEqual(
-            paddle.v2.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
-            paddle.v2.dataset.common.download(
-                yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))
+            paddle.dataset.common.DATA_HOME + '/test/1548775?v=3&s=460',
+            paddle.dataset.common.download(yi_avatar, 'test',
+                                           'f75287202d6622414c706c36c16f8e0d'))

     def test_split(self):
         def test_reader():
@@ -42,7 +42,7 @@ class TestCommon(unittest.TestCase):
             return reader

         _, temp_path = tempfile.mkstemp()
-        paddle.v2.dataset.common.split(
+        paddle.dataset.common.split(
             test_reader(), 4, suffix=temp_path + '/test-%05d.pickle')
         files = glob.glob(temp_path + '/test-%05d.pickle')
         self.assertEqual(len(files), 3)
@@ -52,7 +52,7 @@ class TestCommon(unittest.TestCase):
         for x in xrange(5):
             with open(temp_path + '/%05d.test' % x) as f:
                 f.write('%d\n' % x)
-        reader = paddle.v2.dataset.common.cluster_files_reader(
+        reader = paddle.dataset.common.cluster_files_reader(
             temp_path + '/*.test', 5, 0)
         for idx, e in enumerate(reader()):
             self.assertEqual(e, str("0"))
@@ -69,9 +69,9 @@ class TestCommon(unittest.TestCase):
             return reader

         path = tempfile.mkdtemp()
-        paddle.v2.dataset.common.convert(path,
-                                         test_reader(), num_shards,
-                                         'random_images')
+        paddle.dataset.common.convert(path,
+                                      test_reader(), num_shards,
+                                      'random_images')
         files = glob.glob(path + '/random_images-*')
         self.assertEqual(len(files), num_shards)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.flowers
+import paddle.dataset.flowers
 import unittest
@@ -30,19 +30,19 @@ class TestFlowers(unittest.TestCase):

     def test_train(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.flowers.train())
+            paddle.dataset.flowers.train())
         self.assertEqual(instances, 6149)
         self.assertEqual(max_label_value, 102)

     def test_test(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.flowers.test())
+            paddle.dataset.flowers.test())
         self.assertEqual(instances, 1020)
         self.assertEqual(max_label_value, 102)

     def test_valid(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.flowers.valid())
+            paddle.dataset.flowers.valid())
         self.assertEqual(instances, 1020)
         self.assertEqual(max_label_value, 102)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.imdb
+import paddle.dataset.imdb
 import unittest
 import re
@@ -30,15 +30,13 @@ class TestIMDB(unittest.TestCase):
     def test_build_dict(self):
         if self.word_idx == None:
-            self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
-                                                               150)
+            self.word_idx = paddle.dataset.imdb.build_dict(TRAIN_PATTERN, 150)

         self.assertEqual(len(self.word_idx), 7036)

     def check_dataset(self, dataset, expected_size):
         if self.word_idx == None:
-            self.word_idx = paddle.v2.dataset.imdb.build_dict(TRAIN_PATTERN,
-                                                               150)
+            self.word_idx = paddle.dataset.imdb.build_dict(TRAIN_PATTERN, 150)

         sum = 0
         for l in dataset(self.word_idx):
@@ -47,10 +45,10 @@ class TestIMDB(unittest.TestCase):
         self.assertEqual(sum, expected_size)

     def test_train(self):
-        self.check_dataset(paddle.v2.dataset.imdb.train, 25000)
+        self.check_dataset(paddle.dataset.imdb.train, 25000)

     def test_test(self):
-        self.check_dataset(paddle.v2.dataset.imdb.test, 25000)
+        self.check_dataset(paddle.dataset.imdb.test, 25000)


 if __name__ == '__main__':

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.imikolov
+import paddle.dataset.imikolov
 import unittest

-WORD_DICT = paddle.v2.dataset.imikolov.build_dict()
+WORD_DICT = paddle.dataset.imikolov.build_dict()


 class TestMikolov(unittest.TestCase):
@@ -25,7 +25,7 @@ class TestMikolov(unittest.TestCase):
     def test_train(self):
         n = 5
-        self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
+        self.check_reader(paddle.dataset.imikolov.train(WORD_DICT, n), n)

         first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
                      'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
@@ -34,16 +34,16 @@ class TestMikolov(unittest.TestCase):
             WORD_DICT.get(ch, WORD_DICT['<unk>'])
             for ch in first_line.split(' ')
         ]
-        for l in paddle.v2.dataset.imikolov.train(
-                WORD_DICT, n=-1,
-                data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
+        for l in paddle.dataset.imikolov.train(
+                WORD_DICT, n=-1,
+                data_type=paddle.dataset.imikolov.DataType.SEQ)():
             read_line = l[0][1:]
             break
         self.assertEqual(first_line, read_line)

     def test_test(self):
         n = 5
-        self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
+        self.check_reader(paddle.dataset.imikolov.test(WORD_DICT, n), n)

         first_line = 'consumers may want to move their telephones a little '\
                      'closer to the tv set'
@@ -51,9 +51,9 @@ class TestMikolov(unittest.TestCase):
             WORD_DICT.get(ch, WORD_DICT['<unk>'])
             for ch in first_line.split(' ')
         ]
-        for l in paddle.v2.dataset.imikolov.test(
-                WORD_DICT, n=-1,
-                data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
+        for l in paddle.dataset.imikolov.test(
+                WORD_DICT, n=-1,
+                data_type=paddle.dataset.imikolov.DataType.SEQ)():
             read_line = l[0][1:]
             break
         self.assertEqual(first_line, read_line)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.mnist
+import paddle.dataset.mnist
 import unittest
@@ -29,13 +29,13 @@ class TestMNIST(unittest.TestCase):

     def test_train(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.mnist.train())
+            paddle.dataset.mnist.train())
         self.assertEqual(instances, 60000)
         self.assertEqual(max_label_value, 9)

     def test_test(self):
         instances, max_label_value = self.check_reader(
-            paddle.v2.dataset.mnist.test())
+            paddle.dataset.mnist.test())
         self.assertEqual(instances, 10000)
         self.assertEqual(max_label_value, 9)

@@ -12,19 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle.v2.dataset.mq2007
+import paddle.dataset.mq2007
 import unittest


 class TestMQ2007(unittest.TestCase):
     def test_pairwise(self):
-        for label, query_left, query_right in paddle.v2.dataset.mq2007.test(
+        for label, query_left, query_right in paddle.dataset.mq2007.test(
                 format="pairwise"):
             self.assertEqual(query_left.shape(), (46, ))
             self.assertEqual(query_right.shape(), (46, ))

     def test_listwise(self):
-        for label_array, query_array in paddle.v2.dataset.mq2007.test(
+        for label_array, query_array in paddle.dataset.mq2007.test(
                 format="listwise"):
             self.assertEqual(len(label_array), len(query_array))

@@ -15,7 +15,7 @@
 import unittest
 import numpy as np

-import paddle.v2.image as image
+import paddle.dataset.image as image


 class Image(unittest.TestCase):

Some files were not shown because too many files have changed in this diff.