|
|
|
@ -247,7 +247,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
np.random.seed(self.origin_program.random_seed)
|
|
|
|
|
np.random.shuffle(grad_var_mapping_items)
|
|
|
|
|
|
|
|
|
|
grad_name_to_send_dummy_out = dict()
|
|
|
|
|
self.grad_name_to_send_dummy_out = dict()
|
|
|
|
|
for grad_varname, splited_vars in grad_var_mapping_items:
|
|
|
|
|
eplist = ps_dispatcher.dispatch(splited_vars)
|
|
|
|
|
|
|
|
|
@ -271,7 +271,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
dummy_output = program.global_block().create_var(
|
|
|
|
|
name=framework.generate_control_dev_var_name())
|
|
|
|
|
grad_name_to_send_dummy_out[grad_varname] = dummy_output
|
|
|
|
|
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
|
|
|
|
|
|
|
|
|
|
# get send op_role_var, if not split, the grad should have .trainer suffix
|
|
|
|
|
# if split, grad should be the original grad var name (split_by_ref and send
|
|
|
|
@ -297,7 +297,12 @@ class DistributeTranspiler(object):
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
send_barrier_out = program.global_block().create_var(
|
|
|
|
|
name=framework.generate_control_dev_var_name())
|
|
|
|
|
input_deps = grad_name_to_send_dummy_out.values()
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
self.grad_name_to_send_dummy_out[
|
|
|
|
|
self.table_name] = program.global_block().create_var(
|
|
|
|
|
name=framework.generate_control_dev_var_name())
|
|
|
|
|
input_deps = self.grad_name_to_send_dummy_out.values()
|
|
|
|
|
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
|
type="send_barrier",
|
|
|
|
|
inputs={"X": list(input_deps)},
|
|
|
|
@ -329,7 +334,7 @@ class DistributeTranspiler(object):
|
|
|
|
|
recv_dep_in = send_barrier_out
|
|
|
|
|
else:
|
|
|
|
|
# connect deps to send op in async mode
|
|
|
|
|
recv_dep_in = grad_name_to_send_dummy_out[
|
|
|
|
|
recv_dep_in = self.grad_name_to_send_dummy_out[
|
|
|
|
|
self.param_name_to_grad_name[param_varname]]
|
|
|
|
|
all_recv_outputs.extend(splited_var)
|
|
|
|
|
# get recv op_role_var, if not split, the grad should have .trainer suffix
|
|
|
|
@ -1046,9 +1051,13 @@ class DistributeTranspiler(object):
|
|
|
|
|
index=op_index + 2,
|
|
|
|
|
type="send",
|
|
|
|
|
inputs={'X': self.trainer_side_table_grad_list},
|
|
|
|
|
outputs={'Out': []},
|
|
|
|
|
outputs={
|
|
|
|
|
'Out':
|
|
|
|
|
[self.grad_name_to_send_dummy_out[self.table_name]]
|
|
|
|
|
if self.sync_mode else []
|
|
|
|
|
},
|
|
|
|
|
attrs={
|
|
|
|
|
"sync_mode": True,
|
|
|
|
|
"sync_mode": False,
|
|
|
|
|
"epmap": pserver_endpoints,
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
|
|
|
|
|
OP_ROLE_VAR_ATTR_NAME: [
|
|
|
|
|