|
|
|
@ -464,5 +464,46 @@ class TestDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
def net_conf(self):
|
|
|
|
|
self.network_with_table(is_sparse=True, is_distributed=True)
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
config = fluid.DistributeTranspilerConfig()
|
|
|
|
|
config.sync_mode = False
|
|
|
|
|
|
|
|
|
|
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver1.blocks), 6)
|
|
|
|
|
# 0 listen_and_serv
|
|
|
|
|
# 1 optimize for fc_w or fc_b adam
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[1].ops],
|
|
|
|
|
["adam", "scale", "scale"])
|
|
|
|
|
# 2 optimize for table sgd
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[2].ops], ["sgd"])
|
|
|
|
|
# 3 prefetch -> lookup_sparse_table for data0
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[3].ops],
|
|
|
|
|
["lookup_sparse_table"])
|
|
|
|
|
# 4 prefetch -> lookup_sparse_table for data1
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[4].ops],
|
|
|
|
|
["lookup_sparse_table"])
|
|
|
|
|
# 5 save table
|
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
|
print([op.type for op in trainer.blocks[0].ops])
|
|
|
|
|
ops = [
|
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
|
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
|
|
|
|
|
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
|
|
|
|
|
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
|
|
|
|
|
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
|
|
|
|
|
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
|
|
|
|
|
'sum', 'split_ids', 'send', 'recv', 'recv'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|