add wmt16 into dataset.

add_depthwiseConv_op_gpu
ying 7 years ago
parent 38c61053ff
commit 9a97c7f745

@ -24,11 +24,23 @@ import conll05
import uci_housing
import sentiment
import wmt14
import wmt16
import mq2007
import flowers
import voc2012
__all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
'mnist',
'imikolov',
'imdb',
'cifar',
'movielens',
'conll05',
'sentiment'
'uci_housing',
'wmt14',
'wmt16',
'mq2007',
'flowers',
'voc2012',
]

@ -25,8 +25,12 @@ import glob
import cPickle as pickle
__all__ = [
'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
'convert'
'DATA_HOME',
'download',
'md5file',
'split',
'cluster_files_reader',
'convert',
]
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
@ -58,12 +62,15 @@ def md5file(fname):
return hash_md5.hexdigest()
def download(url, module_name, md5sum):
def download(url, module_name, md5sum, save_name=None):
dirname = os.path.join(DATA_HOME, module_name)
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, url.split('/')[-1])
filename = os.path.join(dirname,
url.split('/')[-1]
if save_name is None else save_name)
retry = 0
retry_limit = 3
while not (os.path.exists(filename) and md5file(filename) == md5sum):
@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param reader: a data reader, from which the convert program will read
data instances.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
:param max_lines_to_shuffle: the max lines numbers to shuffle before
writing.
"""
assert line_count >= 1

@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.wmt16
import unittest
class TestWMT16(unittest.TestCase):
def checkout_one_sample(self, sample):
# train data has 3 field: source language word indices,
# target language word indices, and target next word indices.
self.assertEqual(len(sample), 3)
# test start mark and end mark in source word indices.
self.assertEqual(sample[0][0], 0)
self.assertEqual(sample[0][-1], 1)
# test start mask in target word indices
self.assertEqual(sample[1][0], 0)
# test en mask in target next word indices
self.assertEqual(sample[2][-1], 1)
def test_train(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.train(
src_dict_size=100000, trg_dict_size=100000)()):
if idx >= 10: break
self.checkout_one_sample(sample)
def test_test(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.test(
src_dict_size=1000, trg_dict_size=1000)()):
if idx >= 10: break
self.checkout_one_sample(sample)
def test_val(self):
for idx, sample in enumerate(
paddle.v2.dataset.wmt16.validation(
src_dict_size=1000, trg_dict_size=1000)()):
if idx >= 10: break
self.checkout_one_sample(sample)
def test_get_dict(self):
dict_size = 1000
word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
self.assertEqual(len(word_dict), dict_size)
self.assertEqual(word_dict[0], "<s>")
self.assertEqual(word_dict[1], "<e>")
self.assertEqual(word_dict[2], "<unk>")
if __name__ == "__main__":
unittest.main()

@ -25,12 +25,20 @@ import gzip
import paddle.v2.dataset.common
from paddle.v2.parameters import Parameters
__all__ = ['train', 'test', 'build_dict', 'convert']
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
__all__ = [
'train',
'test',
'get_dict',
'convert',
]
URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
'cslm_joint_paper/data/dev+test.tgz')
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# this is a small set of data for test. The original data is too large and will be add later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
# this is a small set of data for test. The original data is too large and
# will be add later.
URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
'wmt_shrinked_data/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
# BLEU of this trained model is 26.92
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'

File diff suppressed because it is too large Load Diff

@ -19,13 +19,32 @@ import contextlib
from ..registry import autodoc
__all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank',
'ParallelDo', 'Print'
'split_lod_tensor',
'merge_lod_tensor',
'BlockGuard',
'BlockGuardWithCompletion',
'StaticRNNMemoryLink',
'WhileGuard',
'While',
'lod_rank_table',
'max_sequence_len',
'topk',
'lod_tensor_to_array',
'array_to_lod_tensor',
'increment',
'array_write',
'create_array',
'less_than',
'array_read',
'shrink_memory',
'array_length',
'IfElse',
'DynamicRNN',
'ConditionalBlock',
'StaticRNN',
'reorder_lod_tensor_by_rank',
'ParallelDo',
'Print',
]

Loading…
Cancel
Save