Merge pull request #7661 from lcy-seso/wmt16_en_ger

Add WMT16 dataset.
Cao Ying 7 years ago committed by GitHub
commit 430fdc52a8

@@ -24,11 +24,23 @@ import conll05
 import uci_housing
 import sentiment
 import wmt14
+import wmt16
 import mq2007
 import flowers
 import voc2012
 
 __all__ = [
-    'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
-    'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
+    'mnist',
+    'imikolov',
+    'imdb',
+    'cifar',
+    'movielens',
+    'conll05',
+    'sentiment'
+    'uci_housing',
+    'wmt14',
+    'wmt16',
+    'mq2007',
+    'flowers',
+    'voc2012',
 ]
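Note that 'sentiment' still lacks a trailing comma in the reformatted list, so Python's implicit string-literal concatenation silently merges it with the following entry. A small illustration of the pitfall (the two names are just examples):

# Adjacent string literals concatenate at compile time, so the missing comma
# produces one merged entry instead of two separate ones.
names = [
    'sentiment'
    'uci_housing',
]
assert names == ['sentimentuci_housing']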

@@ -25,8 +25,12 @@ import glob
 import cPickle as pickle
 
 __all__ = [
-    'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
-    'convert'
+    'DATA_HOME',
+    'download',
+    'md5file',
+    'split',
+    'cluster_files_reader',
+    'convert',
 ]
 
 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
@@ -58,12 +62,15 @@ def md5file(fname):
     return hash_md5.hexdigest()
 
 
-def download(url, module_name, md5sum):
+def download(url, module_name, md5sum, save_name=None):
     dirname = os.path.join(DATA_HOME, module_name)
     if not os.path.exists(dirname):
         os.makedirs(dirname)
 
-    filename = os.path.join(dirname, url.split('/')[-1])
+    filename = os.path.join(dirname,
+                            url.split('/')[-1]
+                            if save_name is None else save_name)
     retry = 0
     retry_limit = 3
     while not (os.path.exists(filename) and md5file(filename) == md5sum):
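The new save_name argument lets a caller store the downloaded archive under a name other than the last component of the URL. A minimal sketch of the two call forms, reusing the wmt14 URL and checksum that appear later in this diff (the explicit file name is only illustrative):

import paddle.v2.dataset.common as common

URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
             'wmt_shrinked_data/wmt14.tgz')
MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'

# Default: the archive is saved as the last URL component, e.g.
#   ~/.cache/paddle/dataset/wmt14/wmt14.tgz
path = common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)

# With save_name: the same archive is stored under an explicit name, e.g.
#   ~/.cache/paddle/dataset/wmt14/wmt14_train.tgz
path = common.download(URL_TRAIN, 'wmt14', MD5_TRAIN, save_name='wmt14_train.tgz')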
@@ -196,9 +203,11 @@ def convert(output_path, reader, line_count, name_prefix):
     Convert data from reader to recordio format files.
 
     :param output_path: directory in which output files will be saved.
-    :param reader: a data reader, from which the convert program will read data instances.
+    :param reader: a data reader, from which the convert program will read
+                   data instances.
     :param name_prefix: the name prefix of generated files.
-    :param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
+    :param max_lines_to_shuffle: the max lines numbers to shuffle before
+                                 writing.
     """
     assert line_count >= 1
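For context, the dataset modules typically drive this helper through their own convert() wrappers. A minimal sketch, assuming line_count is the number of records per output shard; the output directory and shard prefix are illustrative:

import paddle.v2.dataset.common
import paddle.v2.dataset.wmt14

dict_size = 30000
# Write the wmt14 training set as recordio shards named wmt14_train-*
# under ./recordio.
paddle.v2.dataset.common.convert('./recordio',
                                 paddle.v2.dataset.wmt14.train(dict_size),
                                 1000, 'wmt14_train')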

@@ -0,0 +1,66 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.v2.dataset.wmt16
import unittest


class TestWMT16(unittest.TestCase):
    def checkout_one_sample(self, sample):
        # train data has 3 fields: source language word indices,
        # target language word indices, and target next word indices.
        self.assertEqual(len(sample), 3)

        # test start mark and end mark in source word indices.
        self.assertEqual(sample[0][0], 0)
        self.assertEqual(sample[0][-1], 1)

        # test start mark in target word indices.
        self.assertEqual(sample[1][0], 0)

        # test end mark in target next word indices.
        self.assertEqual(sample[2][-1], 1)

    def test_train(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.train(
                    src_dict_size=100000, trg_dict_size=100000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_test(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.test(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_val(self):
        for idx, sample in enumerate(
                paddle.v2.dataset.wmt16.validation(
                    src_dict_size=1000, trg_dict_size=1000)()):
            if idx >= 10: break
            self.checkout_one_sample(sample)

    def test_get_dict(self):
        dict_size = 1000
        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
        self.assertEqual(len(word_dict), dict_size)
        self.assertEqual(word_dict[0], "<s>")
        self.assertEqual(word_dict[1], "<e>")
        self.assertEqual(word_dict[2], "<unk>")


if __name__ == "__main__":
    unittest.main()

@@ -25,12 +25,20 @@ import gzip
 import paddle.v2.dataset.common
 from paddle.v2.parameters import Parameters
 
-__all__ = ['train', 'test', 'build_dict', 'convert']
+__all__ = [
+    'train',
+    'test',
+    'get_dict',
+    'convert',
+]
 
-URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
+URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
+                'cslm_joint_paper/data/dev+test.tgz')
 MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
-# this is a small set of data for test. The original data is too large and will be add later.
-URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
+# this is a small set of data for test. The original data is too large and
+# will be add later.
+URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
+             'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
@@ -42,8 +50,8 @@ UNK = "<unk>"
 UNK_IDX = 2
 
 
-def __read_to_dict__(tar_file, dict_size):
-    def __to_dict__(fd, size):
+def __read_to_dict(tar_file, dict_size):
+    def __to_dict(fd, size):
         out_dict = dict()
         for line_count, line in enumerate(fd):
             if line_count < size:
@@ -58,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size):
             if each_item.name.endswith("src.dict")
         ]
         assert len(names) == 1
-        src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        src_dict = __to_dict(f.extractfile(names[0]), dict_size)
         names = [
             each_item.name for each_item in f
             if each_item.name.endswith("trg.dict")
         ]
         assert len(names) == 1
-        trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
+        trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
         return src_dict, trg_dict
 
 
 def reader_creator(tar_file, file_name, dict_size):
     def reader():
-        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+        src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
         with tarfile.open(tar_file, mode='r') as f:
             names = [
                 each_item.name for each_item in f
@@ -152,7 +160,7 @@ def get_dict(dict_size, reverse=True):
     # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
     # else reverse = true, return dict = {'001':'a', '002':'b', ...}
     tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
-    src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
+    src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
     if reverse:
         src_dict = {v: k for k, v in src_dict.items()}
         trg_dict = {v: k for k, v in trg_dict.items()}

File diff suppressed because it is too large.
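The suppressed file is presumably the new wmt16 dataset module itself. Based on the test file above, a minimal usage sketch of the readers and dictionary helper it exposes (the dictionary sizes below are only illustrative):

import paddle.v2.dataset.wmt16 as wmt16

# train/test/validation return reader creators; each sample is a triple of
# (source word ids, target word ids, target next-word ids).
train_reader = wmt16.train(src_dict_size=30000, trg_dict_size=30000)
for src_ids, trg_ids, trg_next_ids in train_reader():
    # src_ids starts with the <s> id (0) and ends with the <e> id (1).
    break

# get_dict(lang, dict_size, reverse=True) returns an id-to-word mapping;
# indices 0, 1 and 2 are the <s>, <e> and <unk> tokens checked by the test.
en_dict = wmt16.get_dict("en", 30000, True)
print(en_dict[0])  # <s>
print(en_dict[1])  # <e>
print(en_dict[2])  # <unk>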

@@ -20,13 +20,32 @@ import contextlib
 from ..registry import autodoc
 
 __all__ = [
-    'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
-    'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
-    'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
-    'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
-    'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
-    'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank',
-    'ParallelDo', 'Print'
+    'split_lod_tensor',
+    'merge_lod_tensor',
+    'BlockGuard',
+    'BlockGuardWithCompletion',
+    'StaticRNNMemoryLink',
+    'WhileGuard',
+    'While',
+    'lod_rank_table',
+    'max_sequence_len',
+    'topk',
+    'lod_tensor_to_array',
+    'array_to_lod_tensor',
+    'increment',
+    'array_write',
+    'create_array',
+    'less_than',
+    'array_read',
+    'shrink_memory',
+    'array_length',
+    'IfElse',
+    'DynamicRNN',
+    'ConditionalBlock',
+    'StaticRNN',
+    'reorder_lod_tensor_by_rank',
+    'ParallelDo',
+    'Print',
 ]
