|
|
|
@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
# 5 save table
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
|
|
|
|
|
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
trainer, trainer_startup = self.get_trainer()
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
|
|
|
|
|
startup_ops = [
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
|
|
|
|
|
'fetch_barrier', 'fake_init'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
|
|
|
|
|
startup_ops)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAsyncLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
def net_conf(self):
|
|
|
|
|