optimize the name of table_grad_list

release/0.13.0
qiaolongfei 7 years ago
parent 16027ea111
commit 21f068ab19

@ -257,7 +257,7 @@ class DistributeTranspiler:
][0]
table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.table_grad_list = [
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, trainer_id, index),
@ -267,7 +267,7 @@ class DistributeTranspiler:
for index in range(len(self.pserver_endpoints))
]
else:
self.table_grad_list = [
self.trainer_side_table_grad_list = [
program.global_block().create_var(
name="%s.pserver_%d" % (table_grad_var.name, index),
type=table_grad_var.type,
@ -648,11 +648,11 @@ class DistributeTranspiler:
inputs={
'Ids': [program.global_block().vars[table_grad_name]]
},
outputs={"Out": self.table_grad_list})
outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op(
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
inputs={'X': self.trainer_side_table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
@ -717,7 +717,7 @@ class DistributeTranspiler:
if self.sync_mode:
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_side_table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
@ -727,18 +727,19 @@ class DistributeTranspiler:
for index in range(self.trainer_num)
]
# append sum op for table_grad_list
# append sum op for pserver_side_table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
inputs={"X": pserver_side_table_grad_list},
outputs={"Out": [grad_var]})
else:
# in async_mode, for table gradient, it also need to be splited to each parameter server
origin_grad_name = grad_var.name
splited_grad_name = self.table_grad_list[pserver_index].name
splited_grad_name = self.trainer_side_table_grad_list[
pserver_index].name
if not splited_grad_name.startswith(origin_grad_name):
raise ValueError("origin_grad_var: " + splited_grad_name +
" grad_var:" + grad_var.name)
" grad_var:" + grad_var.name)
grad_var = pserver_program.global_block().rename_var(
origin_grad_name, splited_grad_name)

Loading…
Cancel
Save