|
|
|
@ -15,8 +15,10 @@
|
|
|
|
|
wmt14 dataset
|
|
|
|
|
"""
|
|
|
|
|
import tarfile
|
|
|
|
|
import gzip
|
|
|
|
|
|
|
|
|
|
from paddle.v2.dataset.common import download
|
|
|
|
|
from paddle.v2.parameters import Parameters
|
|
|
|
|
|
|
|
|
|
__all__ = ['train', 'test', 'build_dict']
|
|
|
|
|
|
|
|
|
@ -25,6 +27,9 @@ MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
|
|
|
|
|
# this is a small set of data for test. The original data is too large and will be add later.
|
|
|
|
|
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
|
|
|
|
|
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
|
|
|
|
|
# this is the pretrained model, whose bleu = 26.92
|
|
|
|
|
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
|
|
|
|
|
MD5_MODEL = '6b097d23e15654608c6f74923e975535'
|
|
|
|
|
|
|
|
|
|
START = "<s>"
|
|
|
|
|
END = "<e>"
|
|
|
|
@ -103,5 +108,13 @@ def test(dict_size):
|
|
|
|
|
download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model():
|
|
|
|
|
tar_file = download(URL_MODEL, 'wmt14', MD5_MODEL)
|
|
|
|
|
with gzip.open(tar_file, 'r') as f:
|
|
|
|
|
parameters = Parameters.from_tar(f)
|
|
|
|
|
return parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch():
|
|
|
|
|
download(URL_TRAIN, 'wmt14', MD5_TRAIN)
|
|
|
|
|
download(URL_MODEL, 'wmt14', MD5_MODEL)
|
|
|
|
|