@@ -273,15 +273,25 @@ class DistributeTranspiler:
                 if param_grad[0].name == self.table_name
             ][0]
             table_grad_var = self.table_param_grad[1]
-            self.table_grad_list = [
-                program.global_block().create_var(
-                    name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
-                    type=table_grad_var.type,
-                    shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype)
-                for index in range(len(self.pserver_endpoints))
-            ]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]

         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
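A quick sketch of the naming scheme introduced above, in plain Python with made-up values (the embedding gradient name, trainer id, and endpoints are hypothetical): in sync mode every slice carries the sending trainer's id so the pserver can sum one slice per trainer, while in async mode the trainer tag is dropped and each slice is applied as it arrives.

    # Illustrative only: mirrors the format strings used in the hunk above.
    table_grad_name = "emb.w@GRAD"                            # hypothetical
    trainer_id = 0                                            # hypothetical
    pserver_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]  # hypothetical

    # sync mode: one slice per pserver, tagged with the sending trainer
    sync_names = [
        "%s.trainer_%d.pserver_%d" % (table_grad_name, trainer_id, i)
        for i in range(len(pserver_endpoints))
    ]
    # async mode: no trainer tag
    async_names = ["%s.pserver_%d" % (table_grad_name, i)
                   for i in range(len(pserver_endpoints))]

    print(sync_names)   # ['emb.w@GRAD.trainer_0.pserver_0', 'emb.w@GRAD.trainer_0.pserver_1']
    print(async_names)  # ['emb.w@GRAD.pserver_0', 'emb.w@GRAD.pserver_1']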
@@ -400,7 +410,8 @@ class DistributeTranspiler:
                     attrs={"axis": 0})

         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._replace_lookup_table_op_with_prefetch(program,
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, pserver_endpoints)

     def get_trainer_program(self):
@@ -537,7 +548,7 @@ class DistributeTranspiler:
         if self.has_distributed_lookup_table:
             pserver_index = self.pserver_endpoints.index(endpoint)
             table_opt_block = self._create_table_optimize_block(
-                pserver_index, pserver_program, pre_block_idx)
+                pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             prefetch_block = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)

@@ -621,7 +632,8 @@ class DistributeTranspiler:
         return s_prog

     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program,
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -670,7 +682,7 @@ class DistributeTranspiler:
                         inputs={'X': self.prefetch_input_vars},
                         outputs={"Out": self.prefetch_output_vars},
                         attrs={
-                            "epmap": eplist,
+                            "epmap": pserver_endpoints,
                             RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                         })

@@ -707,11 +719,11 @@ class DistributeTranspiler:
                     inputs={
                         'Ids': [program.global_block().vars[table_grad_name]]
                     },
-                    outputs={"Out": self.table_grad_list})
+                    outputs={"Out": self.trainer_side_table_grad_list})
                 program.global_block().insert_op(
                     index=op_index + 2,
                     type="send_vars",
-                    inputs={'X': self.table_grad_list},
+                    inputs={'X': self.trainer_side_table_grad_list},
                     outputs={},
                     attrs={
                         "sync_send": True,
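The two insert_op calls above place a split_ids op and a send_vars op right after the op that produces the table gradient, so each pserver only receives the rows it is responsible for. A rough sketch of that row partitioning in plain Python (the real work happens inside the split_ids kernel; modulo sharding is an assumption here):

    def split_rows_by_pserver(row_ids, num_pservers):
        """Group sparse row ids so that row r lands in shard r % num_pservers."""
        shards = [[] for _ in range(num_pservers)]
        for r in row_ids:
            shards[r % num_pservers].append(r)
        return shards

    # hypothetical ids touched by one mini-batch, with two pservers
    print(split_rows_by_pserver([0, 3, 4, 7, 10], 2))  # [[0, 4, 10], [3, 7]]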
@@ -750,16 +762,7 @@ class DistributeTranspiler:
         return prefetch_block

     def _create_table_optimize_block(self, pserver_index, pserver_program,
-                                     pre_block_idx):
-        def _clone_var(block, var, persistable=True):
-            assert isinstance(var, Variable)
-            return block.create_var(
-                name=var.name,
-                shape=var.shape,
-                dtype=var.dtype,
-                type=var.type,
-                persistable=persistable)
-
+                                     pre_block_idx, grad_to_block_id):
         # STEP: create table optimize block
         # create table param and grad var in pserver program
         origin_param_var = self.origin_program.global_block().vars[
@@ -770,11 +773,11 @@ class DistributeTranspiler:
             dtype=origin_param_var.dtype,
             type=core.VarDesc.VarType.SELECTED_ROWS,
             persistable=True)
-        grad_var = _clone_var(
-            pserver_program.global_block(),
+        # parameter must be selected rows
+        param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
+        grad_var = pserver_program.global_block().clone_variable(
             self.origin_program.global_block().vars[grad_var_name(
-                self.table_name)],
-            persistable=False)
+                self.table_name)])

         # create table optimize block in pserver program
         table_opt_op = [
@@ -788,7 +791,7 @@ class DistributeTranspiler:
         if self.sync_mode:
             # create grad vars in pserver program
             table_grad_var = self.table_param_grad[1]
-            table_grad_list = [
+            pserver_side_table_grad_list = [
                 pserver_program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
                     (table_grad_var.name, index, pserver_index),
@@ -798,11 +801,21 @@ class DistributeTranspiler:
                 for index in range(self.trainer_num)
             ]

-            # append sum op for table_grad_list
+            # append sum op for pserver_side_table_grad_list
             table_opt_block.append_op(
                 type="sum",
-                inputs={"X": table_grad_list},
+                inputs={"X": pserver_side_table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async_mode, the table gradient also needs to be split and sent to each parameter server
+            origin_grad_name = grad_var.name
+            splited_grad_name = self.trainer_side_table_grad_list[
+                pserver_index].name
+            if not splited_grad_name.startswith(origin_grad_name):
+                raise ValueError("origin_grad_var: " + splited_grad_name +
+                                 " grad_var:" + grad_var.name)
+            grad_var = pserver_program.global_block().rename_var(
+                origin_grad_name, splited_grad_name)

         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -818,6 +831,9 @@ class DistributeTranspiler:
             outputs=outputs,
             attrs=table_opt_op.attrs)

+        # add table parameter gradient and its block id to grad_to_block_id
+        grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
+
         return table_opt_block

     # ====================== private transpiler functions =====================
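With the last hunk the table optimize block registers itself in grad_to_block_id the same way the dense optimize blocks do, as "<grad var name>:<block index>" strings that the pserver side later uses to route incoming gradients to the right block. A small illustration of that format, with hypothetical names and indices:

    grad_to_block_id = ["fc_0.w_0@GRAD:1", "fc_0.b_0@GRAD:2"]  # hypothetical dense entries
    table_grad_name = "emb.w@GRAD.pserver_0"                   # hypothetical renamed table grad
    table_opt_block_idx = 3                                    # hypothetical block index
    grad_to_block_id.append(table_grad_name + ":" + str(table_opt_block_idx))

    # a consumer can recover the grad-name -> block-index mapping like this
    mapping = dict(entry.rsplit(":", 1) for entry in grad_to_block_id)
    print(mapping["emb.w@GRAD.pserver_0"])  # prints: 3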