|
|
|
@ -641,7 +641,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
# 5 save table
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
|
|
|
|
|
|
|
|
|
|
trainer, _ = self.get_trainer(config)
|
|
|
|
|
trainer, trainer_startup = self.get_trainer(config)
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
ops = [
|
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
|
|
|
|
@ -655,6 +655,16 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
'recv', 'concat'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
startup_ops = [
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
|
|
|
|
|
'fill_constant', 'fill_constant', 'uniform_random',
|
|
|
|
|
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
|
|
|
|
|
'fake_init'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
|
|
|
|
|
startup_ops)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
|
|
|
|
|