|
|
@ -701,6 +701,7 @@ class DistributeTranspiler(object):
|
|
|
|
send_vars.append(var)
|
|
|
|
send_vars.append(var)
|
|
|
|
|
|
|
|
|
|
|
|
if self.sync_mode:
|
|
|
|
if self.sync_mode:
|
|
|
|
|
|
|
|
fetch_barrier_input = []
|
|
|
|
send_barrier_out = program.global_block().create_var(
|
|
|
|
send_barrier_out = program.global_block().create_var(
|
|
|
|
name=framework.generate_control_dev_var_name())
|
|
|
|
name=framework.generate_control_dev_var_name())
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
@ -718,6 +719,7 @@ class DistributeTranspiler(object):
|
|
|
|
"trainer_id": self.trainer_id,
|
|
|
|
"trainer_id": self.trainer_id,
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
fetch_barrier_input.append(send_barrier_out)
|
|
|
|
|
|
|
|
|
|
|
|
# step 3: insert recv op to receive parameters from parameter server
|
|
|
|
# step 3: insert recv op to receive parameters from parameter server
|
|
|
|
recv_vars = []
|
|
|
|
recv_vars = []
|
|
|
@ -788,12 +790,14 @@ class DistributeTranspiler(object):
|
|
|
|
OP_ROLE_VAR_ATTR_NAME:
|
|
|
|
OP_ROLE_VAR_ATTR_NAME:
|
|
|
|
[param_varname, recv_op_role_var_name]
|
|
|
|
[param_varname, recv_op_role_var_name]
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
if self.sync_mode:
|
|
|
|
|
|
|
|
fetch_barrier_input.extend(splited_var)
|
|
|
|
|
|
|
|
|
|
|
|
if self.sync_mode:
|
|
|
|
if self.sync_mode:
|
|
|
|
# form a WAW dependency
|
|
|
|
# form a WAW dependency
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|
type="fetch_barrier",
|
|
|
|
type="fetch_barrier",
|
|
|
|
inputs={},
|
|
|
|
inputs={"X": fetch_barrier_input},
|
|
|
|
outputs={"Out": all_recv_outputs},
|
|
|
|
outputs={"Out": all_recv_outputs},
|
|
|
|
attrs={
|
|
|
|
attrs={
|
|
|
|
"endpoints": pserver_endpoints,
|
|
|
|
"endpoints": pserver_endpoints,
|
|
|
|