|
|
|
@ -476,7 +476,7 @@ class DistributeTranspiler:
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
pserver_index = self.pserver_endpoints.index(endpoint)
|
|
|
|
|
table_opt_block = self._create_table_optimize_block(
|
|
|
|
|
pserver_index, pserver_program, pre_block_idx)
|
|
|
|
|
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
|
|
|
|
|
prefetch_block = self._create_prefetch_block(
|
|
|
|
|
pserver_index, pserver_program, table_opt_block)
|
|
|
|
|
|
|
|
|
@ -688,7 +688,7 @@ class DistributeTranspiler:
|
|
|
|
|
return prefetch_block
|
|
|
|
|
|
|
|
|
|
def _create_table_optimize_block(self, pserver_index, pserver_program,
|
|
|
|
|
pre_block_idx):
|
|
|
|
|
pre_block_idx, grad_to_block_id):
|
|
|
|
|
def _clone_var(block, var, persistable=True):
|
|
|
|
|
assert isinstance(var, Variable)
|
|
|
|
|
return block.create_var(
|
|
|
|
@ -743,10 +743,13 @@ class DistributeTranspiler:
|
|
|
|
|
outputs={"Out": [grad_var]})
|
|
|
|
|
else:
|
|
|
|
|
# in async_mode, for table gradient, it also need to be splited to each parameter server
|
|
|
|
|
old_name = grad_var.name
|
|
|
|
|
new_name = old_name + ".pserver_" + str(pserver_index)
|
|
|
|
|
grad_var = pserver_program.global_block().rename_var(old_name,
|
|
|
|
|
new_name)
|
|
|
|
|
origin_grad_name = grad_var.name
|
|
|
|
|
splited_grad_name = self.table_grad_list[pserver_index].name
|
|
|
|
|
if not splited_grad_name.startswith(origin_grad_name):
|
|
|
|
|
raise ValueError("origin_grad_var: " + splited_grad_name +
|
|
|
|
|
" grad_var:" + grad_var.name)
|
|
|
|
|
grad_var = pserver_program.global_block().rename_var(
|
|
|
|
|
origin_grad_name, splited_grad_name)
|
|
|
|
|
|
|
|
|
|
lr_var = pserver_program.global_block().vars[table_opt_op.input(
|
|
|
|
|
"LearningRate")[0]]
|
|
|
|
@ -762,6 +765,9 @@ class DistributeTranspiler:
|
|
|
|
|
outputs=outputs,
|
|
|
|
|
attrs=table_opt_op.attrs)
|
|
|
|
|
|
|
|
|
|
# add table parameter gradient and it's block id to grad_to_block_id
|
|
|
|
|
grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
|
|
|
|
|
|
|
|
|
|
return table_opt_block
|
|
|
|
|
|
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
|