replace concat_op with merge_ids_op

wangkuiyi-patch-1
qiaolongfei 7 years ago
parent 509cb0bc76
commit a941786393

@ -618,7 +618,7 @@ class DistributeTranspiler:
if op.type == LOOKUP_TABLE_TYPE:
continue_search_lookup_table_op = True
op_index = list(all_ops).index(op)
lookup_table_op_index = list(all_ops).index(op)
ids_name = op.input("Ids")
out_name = op.output("Out")
@ -637,7 +637,7 @@ class DistributeTranspiler:
# insert split_ids_op
program.global_block().insert_op(
index=op_index,
index=lookup_table_op_index,
type="split_ids",
inputs={
'Ids': [
@ -649,7 +649,7 @@ class DistributeTranspiler:
# insert prefetch_op
program.global_block().insert_op(
index=op_index + 1,
index=lookup_table_op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={"Out": self.prefetch_output_vars},
@ -660,16 +660,21 @@ class DistributeTranspiler:
# insert concat_op
program.global_block().insert_op(
index=op_index + 2,
type="concat",
inputs={'X': self.prefetch_output_vars},
index=lookup_table_op_index + 2,
type="merge_ids",
inputs={
'Ids': [
program.global_block().vars[varname]
for varname in ids_name
],
'X': self.prefetch_output_vars
},
outputs={
"Out": [
program.global_block().vars[varname]
for varname in out_name
]
},
attrs={"axis": 0})
})
# delete lookup_table_op
delete_ops(program.global_block(), [op])

Loading…
Cancel
Save