|
|
|
@ -521,7 +521,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
'split_selected_rows', 'send', 'sequence_pool_grad',
|
|
|
|
|
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
|
|
|
|
|
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
|
|
|
|
|
'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'concat'
|
|
|
|
|
'recv', 'fetch_barrier'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
|
|
|
|
@ -608,8 +608,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
|
|
|
|
|
'split_selected_rows', 'send', 'sequence_pool_grad',
|
|
|
|
|
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
|
|
|
|
|
'sum', 'split_selected_rows', 'send', 'recv', 'recv', 'recv',
|
|
|
|
|
'recv', 'concat', 'concat'
|
|
|
|
|
'sum', 'split_selected_rows', 'send', 'recv', 'recv'
|
|
|
|
|
]
|
|
|
|
|
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
|
|
|
|
|
|
|
|
|