|
|
@ -25,6 +25,8 @@ import paddle.dataset.wmt16 as wmt16
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
from feed_data_reader import FeedDataReader
|
|
|
|
from feed_data_reader import FeedDataReader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
os.environ['CPU_NUM'] = str(4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelHyperParams(object):
|
|
|
|
class ModelHyperParams(object):
|
|
|
|
# Dictionary size for source and target language. This model directly uses
|
|
|
|
# Dictionary size for source and target language. This model directly uses
|
|
|
@ -185,10 +187,6 @@ def get_feed_data_reader():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestTransformer(TestParallelExecutorBase):
|
|
|
|
class TestTransformer(TestParallelExecutorBase):
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def setUpClass(cls):
|
|
|
|
|
|
|
|
os.environ['CPU_NUM'] = str(4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_main(self):
|
|
|
|
def test_main(self):
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
self.check_network_convergence(
|
|
|
|
self.check_network_convergence(
|
|
|
|