|
|
@ -169,7 +169,7 @@ def append_send_ops_pass(program, config):
|
|
|
|
trainer_id = config.get_role_id()
|
|
|
|
trainer_id = config.get_role_id()
|
|
|
|
pserver_endpoints = config.get_ps_endpoints()
|
|
|
|
pserver_endpoints = config.get_ps_endpoints()
|
|
|
|
|
|
|
|
|
|
|
|
def _append_send_op(union_vars, queue):
|
|
|
|
def _append_grad_send_op(union_vars, queue):
|
|
|
|
|
|
|
|
|
|
|
|
if queue == STEP_COUNTER:
|
|
|
|
if queue == STEP_COUNTER:
|
|
|
|
send_input_vars = []
|
|
|
|
send_input_vars = []
|
|
|
@ -198,6 +198,43 @@ def append_send_ops_pass(program, config):
|
|
|
|
|
|
|
|
|
|
|
|
return dummy_output
|
|
|
|
return dummy_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_sparse_ids_send_op():
    """Append a single "send" op carrying the Ids inputs of sparse ops.

    Walks the ops of the program's global block and, for each op that has
    an ``is_sparse`` attribute, pairs every ``Ids`` input with the ``W``
    input at the same position. Each (Ids, W) pairing is recorded once;
    a pairing identical to the one last recorded for that Ids name is
    skipped. ``lookup_table`` ops additionally get ``remote_prefetch``
    set to False.

    Returns:
        The variable created as the send op's ``Out`` when ``mode`` is
        SYNC or HALF_ASYNC; otherwise an empty list.
    """
    block = program.global_block()

    id_vars = []        # Ids variables fed to the send op as "X"
    table_names = []    # W names, index-aligned with id_vars
    last_table_for = {}  # Ids name -> W name most recently recorded for it

    for op in block.ops:
        if "is_sparse" not in op.all_attrs():
            continue

        if op.type == "lookup_table":
            op._set_attr('remote_prefetch', False)

        for ids_name, table_name in zip(op.input("Ids"), op.input("W")):
            already_recorded = (ids_name in last_table_for and
                                last_table_for[ids_name] == table_name)
            if already_recorded:
                continue
            id_vars.append(block.var(ids_name))
            table_names.append(table_name)
            last_table_for[ids_name] = table_name

    # Only SYNC / HALF_ASYNC get a real output variable; other modes pass
    # an empty list as the op's "Out".
    if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
        dummy_output = block.create_var(
            name=framework.generate_control_dev_var_name())
    else:
        dummy_output = []

    send_attrs = {
        "send_varnames": table_names,
        "merge_add": True,
        "use_send_handler": False,
        "endpoints": pserver_endpoints,
        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
    }
    block.append_op(
        type="send",
        inputs={"X": id_vars},
        outputs={"Out": dummy_output},
        attrs=send_attrs)

    return dummy_output
|
|
|
|
|
|
|
|
|
|
|
|
def _append_barrier_op(dummys):
|
|
|
|
def _append_barrier_op(dummys):
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|
type="send_barrier",
|
|
|
|
type="send_barrier",
|
|
|
@ -214,8 +251,12 @@ def append_send_ops_pass(program, config):
|
|
|
|
|
|
|
|
|
|
|
|
sends = config.get_trainer_send_context()
|
|
|
|
sends = config.get_trainer_send_context()
|
|
|
|
|
|
|
|
|
|
|
|
for merged_name, send in sends.items():
|
|
|
|
if mode == DistributedMode.GEO:
|
|
|
|
dummys.append(_append_send_op(send.origin_varnames(), merged_name))
|
|
|
|
dummys.append(_append_sparse_ids_send_op())
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
for merged_name, send in sends.items():
|
|
|
|
|
|
|
|
dummys.append(
|
|
|
|
|
|
|
|
_append_grad_send_op(send.origin_varnames(), merged_name))
|
|
|
|
|
|
|
|
|
|
|
|
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
|
|
|
|
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
|
|
|
|
_append_barrier_op(dummys)
|
|
|
|
_append_barrier_op(dummys)
|
|
|
|