|
|
|
@ -15,17 +15,55 @@
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import paddle
|
|
|
|
|
from test_dist_base import TestDistBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistTransformer2x2(TestDistBase):
|
|
|
|
|
def download_files():
    """Request every data file used by the dist_transformer test.

    Each file name below is appended to a common URL prefix and passed to
    ``paddle.dataset.common.download`` together with the
    ``test_dist_transformer`` module name and the file's expected MD5
    digest. Files are requested in the order they are listed.

    NOTE(review): presumably ``download`` caches files locally and skips
    ones whose digest already matches -- confirm against the paddle API.
    """
    url_prefix = 'http://paddle-unittest-data.cdn.bcebos.com/dist_transformer/'
    module_name = 'test_dist_transformer'

    # (file name relative to url_prefix, expected MD5 digest)
    required_files = (
        # vocabulary
        ('vocab.bpe.32000', 'a86d345ca6e27f6591d0dccb1b9be853'),
        # local training file
        ('train.tok.clean.bpe.32000.en-de',
         '033eb02b9449e6dd823f050782ac8914'),
        # train_0 / train_1 training files
        ('train.tok.clean.bpe.32000.en-de.train_0',
         'ddce7f602f352a0405267285379a38b1'),
        ('train.tok.clean.bpe.32000.en-de.train_1',
         '8757798200180285b1a619cd7f408747'),
        # test file
        ('newstest2013.tok.bpe.32000.en-de',
         '9dd74a266dbdb25314183899f269b4a2'),
    )

    for file_name, md5sum in required_files:
        paddle.dataset.common.download(url_prefix + file_name, module_name,
                                       md5sum)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistTransformer2x2Sync(TestDistBase):
    """Run dist_transformer.py through TestDistBase with sync mode on."""

    def _setup_config(self):
        """Enable synchronous mode by setting ``_sync_mode`` to True."""
        self._sync_mode = True

    def test_transformer(self):
        """Check dist_transformer.py with a tolerance of 1e-7.

        NOTE(review): ``delta`` is presumably the largest allowed
        difference between the compared losses -- confirm in
        ``TestDistBase.check_with_place``.
        """
        # The data files must be requested before the check is started;
        # running the check first would start it without them.
        download_files()
        #Note: loss on test dataset of the first 5 batch are:
        # 10.518872, 10.518871, 10.518868, 10.518862, 10.518855
        self.check_with_place("dist_transformer.py", delta=1e-7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistTransformer2x2Async(TestDistBase):
    """Run dist_transformer.py through TestDistBase with sync mode off."""

    def _setup_config(self):
        """Disable synchronous mode by setting ``_sync_mode`` to False."""
        self._sync_mode = False

    def test_transformer(self):
        """Request the data files, then check with a tolerance of 1.0."""
        download_files()
        model_script = "dist_transformer.py"
        self.check_with_place(model_script, delta=1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|