|
|
|
@ -18,6 +18,7 @@ import unittest
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid.transpiler.distribute_transpiler import delete_ops
|
|
|
|
|
import traceback
|
|
|
|
|
import collections
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranspilerTest(unittest.TestCase):
|
|
|
|
@ -53,9 +54,18 @@ class TranspilerTest(unittest.TestCase):
|
|
|
|
|
self.origin_prog = main.clone()
|
|
|
|
|
return main
|
|
|
|
|
|
|
|
|
|
def get_trainer(self, config=None, sync_mode=True):
|
|
|
|
|
t = self._transpiler_instance(config, sync_mode)
|
|
|
|
|
return t.get_trainer_program()
|
|
|
|
|
def get_trainer(self, config=None):
|
|
|
|
|
src = fluid.default_startup_program().clone()
|
|
|
|
|
|
|
|
|
|
t = self._transpiler_instance(config)
|
|
|
|
|
|
|
|
|
|
trainer_main = t.get_trainer_program()
|
|
|
|
|
trainer_startup = fluid.default_startup_program()
|
|
|
|
|
|
|
|
|
|
assert (src.num_blocks == 1)
|
|
|
|
|
assert (trainer_startup.num_blocks == src.num_blocks)
|
|
|
|
|
|
|
|
|
|
return trainer_main, trainer_startup
|
|
|
|
|
|
|
|
|
|
def get_pserver(self, ep, config=None, sync_mode=True):
|
|
|
|
|
t = self._transpiler_instance(config, sync_mode)
|
|
|
|
@ -91,7 +101,21 @@ class TestBasicModel(TranspilerTest):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep)
|
|
|
|
|
pserver2, startup2 = self.get_pserver(self.pserver2_ep)
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, trainer_startup = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
# splited var blocks should be in startup program
|
|
|
|
|
self.assertTrue("fc_w.block0" in trainer_startup.global_block().vars)
|
|
|
|
|
self.assertTrue("fc_w.block1" in trainer_startup.global_block().vars)
|
|
|
|
|
self.assertTrue("fc_w" in trainer_startup.global_block().vars)
|
|
|
|
|
self.assertTrue("fc_b" in trainer_startup.global_block().vars)
|
|
|
|
|
self.assertTrue("fc_w@GRAD" not in trainer_startup.global_block().vars)
|
|
|
|
|
self.assertTrue("fc_b@GRAD" not in trainer_startup.global_block().vars)
|
|
|
|
|
|
|
|
|
|
src = [op.type for op in trainer_startup.global_block().ops]
|
|
|
|
|
dst = ['fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv', \
|
|
|
|
|
'fetch_barrier', 'concat']
|
|
|
|
|
|
|
|
|
|
self.assertEqual(src, dst)
|
|
|
|
|
|
|
|
|
|
self.assertEqual([op.type for op in trainer.global_block().ops], [
|
|
|
|
|
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
|
|
|
|
@ -142,7 +166,7 @@ class TestBasicModelWithLargeBlockSize(TranspilerTest):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep, config)
|
|
|
|
|
pserver2, startup2 = self.get_pserver(self.pserver2_ep, config)
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
|
trainer, _ = self.get_trainer(config)
|
|
|
|
|
|
|
|
|
|
self.assertEqual([op.type for op in trainer.global_block().ops], [
|
|
|
|
|
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
|
|
|
|
@ -226,7 +250,7 @@ class TestLRDecay(TranspilerTest):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep)
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver.blocks), 4)
|
|
|
|
|
lr_decay_ops = [op.type for op in pserver.blocks[1].ops]
|
|
|
|
@ -256,7 +280,7 @@ class TestLRDecayConditional(TranspilerTest):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep)
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
serv_op = pserver.blocks[0].ops[0]
|
|
|
|
|
sub_blocks = []
|
|
|
|
@ -305,7 +329,7 @@ class TestL2Decay(TranspilerTest):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep)
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver.blocks), 3)
|
|
|
|
|
self.assertEqual([op.type for op in pserver.blocks[1].ops],
|
|
|
|
@ -340,7 +364,7 @@ class TestL2DecayWithPiecewise(TranspilerTest):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
pserver, startup = self.get_pserver(self.pserver1_ep)
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver.blocks), 9)
|
|
|
|
|
self.assertEqual([op.type for op in pserver.blocks[1].ops], [
|
|
|
|
@ -415,7 +439,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
|
|
|
|
|
["sum", "adam", "scale", "scale"])
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
|
|
|
|
@ -453,7 +477,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
# 5 save table
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
@ -486,7 +510,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
|
|
|
|
|
["adam", "scale", "scale"])
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
|
trainer, _ = self.get_trainer(config)
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
|
|
|
|
@ -525,7 +549,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
# 5 save table
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
|
trainer, _ = self.get_trainer(config)
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
|