@@ -256,15 +256,25 @@ class DistributeTranspiler:
             if param_grad[0].name == self.table_name
         ][0]
         table_grad_var = self.table_param_grad[1]
-        self.table_grad_list = [
-            program.global_block().create_var(
-                name="%s.trainer_%d.pserver_%d" %
-                (table_grad_var.name, trainer_id, index),
-                type=table_grad_var.type,
-                shape=table_grad_var.shape,
-                dtype=table_grad_var.dtype)
-            for index in range(len(self.pserver_endpoints))
-        ]
+        if self.sync_mode:
+            self.table_grad_list = [
+                program.global_block().create_var(
+                    name="%s.trainer_%d.pserver_%d" %
+                    (table_grad_var.name, trainer_id, index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(len(self.pserver_endpoints))
+            ]
+        else:
+            self.table_grad_list = [
+                program.global_block().create_var(
+                    name="%s.pserver_%d" % (table_grad_var.name, index),
+                    type=table_grad_var.type,
+                    shape=table_grad_var.shape,
+                    dtype=table_grad_var.dtype)
+                for index in range(len(self.pserver_endpoints))
+            ]

         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
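
For orientation, the naming split introduced above can be sketched without the framework. The helper below is hypothetical (it is not part of the transpiler); it only mirrors how one gradient-variable name per parameter server is built in each mode.

    # Hypothetical sketch of the naming scheme above, not PaddlePaddle API.
    def table_grad_names(base_name, trainer_id, num_pservers, sync_mode):
        """Return one gradient-variable name per parameter server."""
        if sync_mode:
            # Sync mode: every trainer pushes its own gradient shard, so the
            # trainer id is baked into the name to keep shards from different
            # trainers apart until the pserver sums them.
            return [
                "%s.trainer_%d.pserver_%d" % (base_name, trainer_id, index)
                for index in range(num_pservers)
            ]
        # Async mode: shards are applied as they arrive, so only the target
        # parameter server is encoded in the name.
        return [
            "%s.pserver_%d" % (base_name, index)
            for index in range(num_pservers)
        ]

    print(table_grad_names("emb@GRAD", 0, 2, sync_mode=True))
    # ['emb@GRAD.trainer_0.pserver_0', 'emb@GRAD.trainer_0.pserver_1']
    print(table_grad_names("emb@GRAD", 0, 2, sync_mode=False))
    # ['emb@GRAD.pserver_0', 'emb@GRAD.pserver_1']
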
@@ -328,7 +338,7 @@ class DistributeTranspiler:

         if self.has_distributed_lookup_table:
             self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
-                                                        eplist)
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, rpc_client_var,
                                                      pserver_endpoints)
@@ -551,7 +561,7 @@ class DistributeTranspiler:

     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
-                                               eplist):
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
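
The comment in this hunk names the rewrite: one lookup_table_op becomes a split_ids op, a per-pserver prefetch, and a merge of the fetched rows. Below is a pure-Python mock of that data flow, with plain dicts standing in for the sharded table and for the RPC prefetch; restoring the original row order is omitted for brevity.

    # Illustrative mock of split_ids -> prefetch -> merge (not Fluid ops).
    def split_ids(ids, num_pservers):
        # Route each id to the parameter server owning its table shard.
        shards = [[] for _ in range(num_pservers)]
        for i in ids:
            shards[i % num_pservers].append(i)
        return shards

    def prefetch(shard_ids, table_shard):
        # Stand-in for the RPC prefetch: look rows up on one pserver.
        return [table_shard[i] for i in shard_ids]

    tables = [{0: "row0", 2: "row2"}, {1: "row1", 3: "row3"}]  # shard by id % 2
    shards = split_ids([3, 0, 1], num_pservers=2)              # [[0], [3, 1]]
    fetched = [prefetch(s, t) for s, t in zip(shards, tables)]
    print([row for part in fetched for row in part])           # merge step
    # ['row0', 'row3', 'row1']
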
@@ -602,7 +612,7 @@ class DistributeTranspiler:
                 "Out": self.prefetch_output_vars,
                 "RPCClient": rpc_client_var
             },
-            attrs={"epmap": eplist})
+            attrs={"epmap": pserver_endpoints})

         # insert concat_op
         program.global_block().insert_op(
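
Roughly, the epmap attribute is positional: entry i tells the prefetch op which endpoint serves the i-th prefetch variable. A small illustration with hypothetical variable names:

    # Hypothetical illustration of a positional endpoint map (epmap).
    pserver_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
    prefetch_outputs = ["emb.prefetch.out.0", "emb.prefetch.out.1"]
    for var_name, endpoint in zip(prefetch_outputs, pserver_endpoints):
        print("fetch %s from %s" % (var_name, endpoint))
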
@@ -731,6 +741,12 @@ class DistributeTranspiler:
                 type="sum",
                 inputs={"X": table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async_mode, the table gradient also needs to be split to each parameter server
+            old_name = grad_var.name
+            new_name = old_name + ".pserver_" + str(pserver_index)
+            grad_var = pserver_program.global_block().rename_var(old_name,
+                                                                 new_name)

         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
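
A minimal sketch of the branch above, under the same suffix convention; the helper is hypothetical. In sync mode the per-trainer gradient splits are summed on the pserver, while in async mode the single split is renamed so the optimizer reads the per-pserver variable.

    # Illustrative sketch of the sync/async table-gradient branch above.
    def resolve_table_grad(grad_name, trainer_splits, pserver_index, sync_mode):
        if sync_mode:
            # Mock of the sum op over all trainers' gradient splits.
            return grad_name, sum(trainer_splits)
        # Mock of the rename: one split per pserver, applied as it arrives.
        return grad_name + ".pserver_" + str(pserver_index), trainer_splits[0]

    print(resolve_table_grad("emb@GRAD", [1.0, 2.0], 0, sync_mode=True))
    # ('emb@GRAD', 3.0)
    print(resolve_table_grad("emb@GRAD", [4.0], 0, sync_mode=False))
    # ('emb@GRAD.pserver_0', 4.0)
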