|
|
|
@ -257,7 +257,7 @@ class DistributeTranspiler:
|
|
|
|
|
][0]
|
|
|
|
|
table_grad_var = self.table_param_grad[1]
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
self.table_grad_list = [
|
|
|
|
|
self.trainer_side_table_grad_list = [
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name="%s.trainer_%d.pserver_%d" %
|
|
|
|
|
(table_grad_var.name, trainer_id, index),
|
|
|
|
@ -267,7 +267,7 @@ class DistributeTranspiler:
|
|
|
|
|
for index in range(len(self.pserver_endpoints))
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
self.table_grad_list = [
|
|
|
|
|
self.trainer_side_table_grad_list = [
|
|
|
|
|
program.global_block().create_var(
|
|
|
|
|
name="%s.pserver_%d" % (table_grad_var.name, index),
|
|
|
|
|
type=table_grad_var.type,
|
|
|
|
@ -648,11 +648,11 @@ class DistributeTranspiler:
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': [program.global_block().vars[table_grad_name]]
|
|
|
|
|
},
|
|
|
|
|
outputs={"Out": self.table_grad_list})
|
|
|
|
|
outputs={"Out": self.trainer_side_table_grad_list})
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index + 2,
|
|
|
|
|
type="send_vars",
|
|
|
|
|
inputs={'X': self.table_grad_list},
|
|
|
|
|
inputs={'X': self.trainer_side_table_grad_list},
|
|
|
|
|
outputs={"RPCClient": rpc_client_var},
|
|
|
|
|
attrs={"sync_send": True,
|
|
|
|
|
"epmap": pserver_endpoints})
|
|
|
|
@ -717,7 +717,7 @@ class DistributeTranspiler:
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
# create grad vars in pserver program
|
|
|
|
|
table_grad_var = self.table_param_grad[1]
|
|
|
|
|
table_grad_list = [
|
|
|
|
|
pserver_side_table_grad_list = [
|
|
|
|
|
pserver_program.global_block().create_var(
|
|
|
|
|
name="%s.trainer_%d.pserver_%d" %
|
|
|
|
|
(table_grad_var.name, index, pserver_index),
|
|
|
|
@ -727,18 +727,19 @@ class DistributeTranspiler:
|
|
|
|
|
for index in range(self.trainer_num)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# append sum op for table_grad_list
|
|
|
|
|
# append sum op for pserver_side_table_grad_list
|
|
|
|
|
table_opt_block.append_op(
|
|
|
|
|
type="sum",
|
|
|
|
|
inputs={"X": table_grad_list},
|
|
|
|
|
inputs={"X": pserver_side_table_grad_list},
|
|
|
|
|
outputs={"Out": [grad_var]})
|
|
|
|
|
else:
|
|
|
|
|
# in async_mode, for table gradient, it also need to be splited to each parameter server
|
|
|
|
|
origin_grad_name = grad_var.name
|
|
|
|
|
splited_grad_name = self.table_grad_list[pserver_index].name
|
|
|
|
|
splited_grad_name = self.trainer_side_table_grad_list[
|
|
|
|
|
pserver_index].name
|
|
|
|
|
if not splited_grad_name.startswith(origin_grad_name):
|
|
|
|
|
raise ValueError("origin_grad_var: " + splited_grad_name +
|
|
|
|
|
" grad_var:" + grad_var.name)
|
|
|
|
|
" grad_var:" + grad_var.name)
|
|
|
|
|
grad_var = pserver_program.global_block().rename_var(
|
|
|
|
|
origin_grad_name, splited_grad_name)
|
|
|
|
|
|
|
|
|
|