|
|
|
@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
|
|
|
|
|
|
|
|
|
|
|
|
LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
|
|
|
|
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VarBlock:
|
|
|
|
class VarBlock:
|
|
|
|
@ -297,11 +299,6 @@ class DistributeTranspiler:
|
|
|
|
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
|
|
|
|
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
|
|
|
|
param_var_mapping[p_name][int(p_bid)]
|
|
|
|
param_var_mapping[p_name][int(p_bid)]
|
|
|
|
|
|
|
|
|
|
|
|
rpc_client_var = program.global_block().create_var(
|
|
|
|
|
|
|
|
name=RPC_CLIENT_VAR_NAME,
|
|
|
|
|
|
|
|
persistable=True,
|
|
|
|
|
|
|
|
type=core.VarDesc.VarType.RAW)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# step 3: transpile trainer side program, insert recv op and send op.
|
|
|
|
# step 3: transpile trainer side program, insert recv op and send op.
|
|
|
|
|
|
|
|
|
|
|
|
# create mapping of endpoint -> split var to create pserver side program
|
|
|
|
# create mapping of endpoint -> split var to create pserver side program
|
|
|
|
@ -338,8 +335,11 @@ class DistributeTranspiler:
|
|
|
|
index=index + 1,
|
|
|
|
index=index + 1,
|
|
|
|
type="send_vars",
|
|
|
|
type="send_vars",
|
|
|
|
inputs={"X": splited_vars},
|
|
|
|
inputs={"X": splited_vars},
|
|
|
|
outputs={"RPCClient": rpc_client_var},
|
|
|
|
outputs={},
|
|
|
|
attrs={"epmap": eplist})
|
|
|
|
attrs={
|
|
|
|
|
|
|
|
"epmap": eplist,
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
|
|
|
})
|
|
|
|
for _, var in enumerate(splited_vars):
|
|
|
|
for _, var in enumerate(splited_vars):
|
|
|
|
send_vars.append(var)
|
|
|
|
send_vars.append(var)
|
|
|
|
|
|
|
|
|
|
|
|
@ -347,10 +347,11 @@ class DistributeTranspiler:
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|
type="send_barrier",
|
|
|
|
type="send_barrier",
|
|
|
|
inputs={},
|
|
|
|
inputs={},
|
|
|
|
outputs={"RPCClient": rpc_client_var},
|
|
|
|
outputs={},
|
|
|
|
attrs={
|
|
|
|
attrs={
|
|
|
|
"endpoints": pserver_endpoints,
|
|
|
|
"endpoints": pserver_endpoints,
|
|
|
|
"sync_mode": self.sync_mode
|
|
|
|
"sync_mode": self.sync_mode,
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
# step 3.2: insert recv op to receive parameters from parameter server
|
|
|
|
# step 3.2: insert recv op to receive parameters from parameter server
|
|
|
|
@ -373,15 +374,20 @@ class DistributeTranspiler:
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|
type="recv",
|
|
|
|
type="recv",
|
|
|
|
inputs={},
|
|
|
|
inputs={},
|
|
|
|
outputs={"Out": splited_var,
|
|
|
|
outputs={"Out": splited_var},
|
|
|
|
"RPCClient": rpc_client_var},
|
|
|
|
attrs={
|
|
|
|
attrs={"epmap": eps})
|
|
|
|
"epmap": eps,
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
program.global_block().append_op(
|
|
|
|
program.global_block().append_op(
|
|
|
|
type="fetch_barrier",
|
|
|
|
type="fetch_barrier",
|
|
|
|
inputs={},
|
|
|
|
inputs={},
|
|
|
|
outputs={"RPCClient": rpc_client_var},
|
|
|
|
outputs={},
|
|
|
|
attrs={"endpoints": pserver_endpoints})
|
|
|
|
attrs={
|
|
|
|
|
|
|
|
"endpoints": pserver_endpoints,
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
for varname, splited_var in param_var_mapping.iteritems():
|
|
|
|
for varname, splited_var in param_var_mapping.iteritems():
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
if len(splited_var) <= 1:
|
|
|
|
@ -394,10 +400,8 @@ class DistributeTranspiler:
|
|
|
|
attrs={"axis": 0})
|
|
|
|
attrs={"axis": 0})
|
|
|
|
|
|
|
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
|
|
|
|
self._replace_lookup_table_op_with_prefetch(program, eplist)
|
|
|
|
eplist)
|
|
|
|
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
|
|
|
|
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
|
|
|
|
|
|
|
|
pserver_endpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_trainer_program(self):
|
|
|
|
def get_trainer_program(self):
|
|
|
|
# remove optimize ops and add a send op to main_program
|
|
|
|
# remove optimize ops and add a send op to main_program
|
|
|
|
@ -617,8 +621,7 @@ class DistributeTranspiler:
|
|
|
|
return s_prog
|
|
|
|
return s_prog
|
|
|
|
|
|
|
|
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
# transpiler function for dis lookup_table
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program, eplist):
|
|
|
|
eplist):
|
|
|
|
|
|
|
|
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
|
|
|
|
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
|
|
|
|
self.prefetch_input_vars = None
|
|
|
|
self.prefetch_input_vars = None
|
|
|
|
self.prefetch_output_vars = None
|
|
|
|
self.prefetch_output_vars = None
|
|
|
|
@ -665,11 +668,11 @@ class DistributeTranspiler:
|
|
|
|
index=op_index + 1,
|
|
|
|
index=op_index + 1,
|
|
|
|
type="prefetch",
|
|
|
|
type="prefetch",
|
|
|
|
inputs={'X': self.prefetch_input_vars},
|
|
|
|
inputs={'X': self.prefetch_input_vars},
|
|
|
|
outputs={
|
|
|
|
outputs={"Out": self.prefetch_output_vars},
|
|
|
|
"Out": self.prefetch_output_vars,
|
|
|
|
attrs={
|
|
|
|
"RPCClient": rpc_client_var
|
|
|
|
"epmap": eplist,
|
|
|
|
},
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
attrs={"epmap": eplist})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
# insert concat_op
|
|
|
|
# insert concat_op
|
|
|
|
program.global_block().insert_op(
|
|
|
|
program.global_block().insert_op(
|
|
|
|
@ -689,8 +692,7 @@ class DistributeTranspiler:
|
|
|
|
# break for loop
|
|
|
|
# break for loop
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
|
|
|
|
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
|
|
|
|
pserver_endpoints):
|
|
|
|
|
|
|
|
# 2. add split_ids_op and send_vars_op to send gradient to pservers
|
|
|
|
# 2. add split_ids_op and send_vars_op to send gradient to pservers
|
|
|
|
# there should only be one table_name
|
|
|
|
# there should only be one table_name
|
|
|
|
all_ops = program.global_block().ops
|
|
|
|
all_ops = program.global_block().ops
|
|
|
|
@ -710,9 +712,12 @@ class DistributeTranspiler:
|
|
|
|
index=op_index + 2,
|
|
|
|
index=op_index + 2,
|
|
|
|
type="send_vars",
|
|
|
|
type="send_vars",
|
|
|
|
inputs={'X': self.table_grad_list},
|
|
|
|
inputs={'X': self.table_grad_list},
|
|
|
|
outputs={"RPCClient": rpc_client_var},
|
|
|
|
outputs={},
|
|
|
|
attrs={"sync_send": True,
|
|
|
|
attrs={
|
|
|
|
"epmap": pserver_endpoints})
|
|
|
|
"sync_send": True,
|
|
|
|
|
|
|
|
"epmap": pserver_endpoints,
|
|
|
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
|
|
|
})
|
|
|
|
break
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
def _create_prefetch_block(self, pserver_index, pserver_program,
|
|
|
|
def _create_prefetch_block(self, pserver_index, pserver_program,
|
|
|
|
|