|
|
|
@ -618,7 +618,7 @@ class DistributeTranspiler:
|
|
|
|
|
if op.type == LOOKUP_TABLE_TYPE:
|
|
|
|
|
continue_search_lookup_table_op = True
|
|
|
|
|
|
|
|
|
|
op_index = list(all_ops).index(op)
|
|
|
|
|
lookup_table_op_index = list(all_ops).index(op)
|
|
|
|
|
ids_name = op.input("Ids")
|
|
|
|
|
out_name = op.output("Out")
|
|
|
|
|
|
|
|
|
@ -637,7 +637,7 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
# insert split_ids_op
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index,
|
|
|
|
|
index=lookup_table_op_index,
|
|
|
|
|
type="split_ids",
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': [
|
|
|
|
@ -649,7 +649,7 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
# insert prefetch_op
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index + 1,
|
|
|
|
|
index=lookup_table_op_index + 1,
|
|
|
|
|
type="prefetch",
|
|
|
|
|
inputs={'X': self.prefetch_input_vars},
|
|
|
|
|
outputs={"Out": self.prefetch_output_vars},
|
|
|
|
@ -660,16 +660,21 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
# insert concat_op
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index + 2,
|
|
|
|
|
type="concat",
|
|
|
|
|
inputs={'X': self.prefetch_output_vars},
|
|
|
|
|
index=lookup_table_op_index + 2,
|
|
|
|
|
type="merge_ids",
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
|
for varname in ids_name
|
|
|
|
|
],
|
|
|
|
|
'X': self.prefetch_output_vars
|
|
|
|
|
},
|
|
|
|
|
outputs={
|
|
|
|
|
"Out": [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
|
for varname in out_name
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
attrs={"axis": 0})
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# delete lookup_table_op
|
|
|
|
|
delete_ops(program.global_block(), [op])
|
|
|
|
|