|
|
|
@ -477,12 +477,23 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
# delete table init op
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
trainer_table_param_init_op = []
|
|
|
|
|
table_var = self.startup_program.global_block().vars[
|
|
|
|
|
self.table_name]
|
|
|
|
|
table_param_init_op = []
|
|
|
|
|
for op in self.startup_program.global_block().ops:
|
|
|
|
|
if self.table_name in op.output_arg_names:
|
|
|
|
|
trainer_table_param_init_op.append(op)
|
|
|
|
|
delete_ops(self.startup_program.global_block(),
|
|
|
|
|
trainer_table_param_init_op)
|
|
|
|
|
table_param_init_op.append(op)
|
|
|
|
|
init_op_num = len(table_param_init_op)
|
|
|
|
|
if init_op_num != 1:
|
|
|
|
|
raise ValueError("table init op num should be 1, now is " + str(
|
|
|
|
|
init_op_num))
|
|
|
|
|
table_init_op = table_param_init_op[1]
|
|
|
|
|
self.startup_program.global_block().append_op(
|
|
|
|
|
type="fake_init",
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={"Out": table_var},
|
|
|
|
|
attrs={"shape": table_init_op.attr('shape')})
|
|
|
|
|
delete_ops(self.startup_program.global_block(), table_param_init_op)
|
|
|
|
|
|
|
|
|
|
self.origin_program.__str__()
|
|
|
|
|
|
|
|
|
|