fix dist table send hang problem (#13259)

* fix dist table send hang problem

* revert sync_mode config

* fix async send table
upload-readme
Qiao Longfei 7 years ago committed by GitHub
parent 2c31ea9293
commit 020d13c18a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -247,7 +247,7 @@ class DistributeTranspiler(object):
np.random.seed(self.origin_program.random_seed)
np.random.shuffle(grad_var_mapping_items)
grad_name_to_send_dummy_out = dict()
self.grad_name_to_send_dummy_out = dict()
for grad_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars)
@ -271,7 +271,7 @@ class DistributeTranspiler(object):
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
grad_name_to_send_dummy_out[grad_varname] = dummy_output
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
# get send op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name (split_by_ref and send
@ -297,7 +297,12 @@ class DistributeTranspiler(object):
if self.sync_mode:
send_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = grad_name_to_send_dummy_out.values()
if self.has_distributed_lookup_table:
self.grad_name_to_send_dummy_out[
self.table_name] = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
input_deps = self.grad_name_to_send_dummy_out.values()
program.global_block().append_op(
type="send_barrier",
inputs={"X": list(input_deps)},
@ -329,7 +334,7 @@ class DistributeTranspiler(object):
recv_dep_in = send_barrier_out
else:
# connect deps to send op in async mode
recv_dep_in = grad_name_to_send_dummy_out[
recv_dep_in = self.grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
@ -1046,9 +1051,13 @@ class DistributeTranspiler(object):
index=op_index + 2,
type="send",
inputs={'X': self.trainer_side_table_grad_list},
outputs={'Out': []},
outputs={
'Out':
[self.grad_name_to_send_dummy_out[self.table_name]]
if self.sync_mode else []
},
attrs={
"sync_mode": True,
"sync_mode": False,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [

Loading…
Cancel
Save