|
|
|
@ -474,6 +474,15 @@ class DistributeTranspiler(object):
|
|
|
|
|
delete_ops(self.origin_program.global_block(), self.optimize_ops)
|
|
|
|
|
delete_ops(self.origin_program.global_block(), lr_ops)
|
|
|
|
|
|
|
|
|
|
# delete table init op
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
trainer_table_param_init_op = []
|
|
|
|
|
for op in self.startup_program.global_block().ops:
|
|
|
|
|
if self.table_name in op.output_arg_names:
|
|
|
|
|
trainer_table_param_init_op.append(op)
|
|
|
|
|
delete_ops(self.startup_program.global_block(),
|
|
|
|
|
trainer_table_param_init_op)
|
|
|
|
|
|
|
|
|
|
self.origin_program.__str__()
|
|
|
|
|
|
|
|
|
|
if wait_port:
|
|
|
|
@ -1194,9 +1203,8 @@ to transpile() call.")
|
|
|
|
|
# create table param and grad var in pserver program
|
|
|
|
|
# create table optimize block in pserver program
|
|
|
|
|
table_opt_op = [
|
|
|
|
|
op for op in self.optimize_ops
|
|
|
|
|
if 'Param' in op.input_names and op.input("Param")[0] ==
|
|
|
|
|
self.table_name
|
|
|
|
|
op for op in self.optimize_ops if 'Param' in op.input_names and
|
|
|
|
|
op.input("Param")[0] == self.table_name
|
|
|
|
|
][0]
|
|
|
|
|
|
|
|
|
|
origin_param_var = self.origin_program.global_block().vars[
|
|
|
|
|