|
|
|
@ -515,21 +515,20 @@ class DistributeTranspiler:
|
|
|
|
|
grad_to_block_id, None)
|
|
|
|
|
|
|
|
|
|
# process distributed lookup_table
|
|
|
|
|
prefetch_block = None
|
|
|
|
|
prefetch_var_name_to_block_id = []
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
pserver_index = self.pserver_endpoints.index(endpoint)
|
|
|
|
|
table_opt_block = self._create_table_optimize_block(
|
|
|
|
|
pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
|
|
|
|
|
prefetch_block = self._create_prefetch_block(
|
|
|
|
|
prefetch_var_name_to_block_id = self._create_prefetch_block(
|
|
|
|
|
pserver_index, pserver_program, table_opt_block)
|
|
|
|
|
|
|
|
|
|
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
|
|
|
|
|
# not be executed, so it's safe to use optimize_block to hold the place
|
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
|
assert prefetch_block is not None
|
|
|
|
|
assert len(prefetch_var_name_to_block_id) > 0
|
|
|
|
|
else:
|
|
|
|
|
assert prefetch_block is None
|
|
|
|
|
prefetch_block = pserver_program.global_block()
|
|
|
|
|
assert len(prefetch_var_name_to_block_id) == 0
|
|
|
|
|
|
|
|
|
|
# step5 append the listen_and_serv op
|
|
|
|
|
pserver_program.global_block().append_op(
|
|
|
|
@ -540,7 +539,7 @@ class DistributeTranspiler:
|
|
|
|
|
"OptimizeBlock": pserver_program.block(1),
|
|
|
|
|
"endpoint": endpoint,
|
|
|
|
|
"Fanin": self.trainer_num,
|
|
|
|
|
"PrefetchBlock": prefetch_block,
|
|
|
|
|
"prefetch_var_name_to_block_id": prefetch_var_name_to_block_id,
|
|
|
|
|
"sync_mode": self.sync_mode,
|
|
|
|
|
"grad_to_block_id": grad_to_block_id
|
|
|
|
|
})
|
|
|
|
@ -608,8 +607,15 @@ class DistributeTranspiler:
|
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program,
|
|
|
|
|
pserver_endpoints):
|
|
|
|
|
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
|
|
|
|
|
self.prefetch_input_vars = None
|
|
|
|
|
self.prefetch_output_vars = None
|
|
|
|
|
# self.all_prefetch_input_vars =
|
|
|
|
|
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
|
|
|
|
|
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
|
|
|
|
|
self.all_prefetch_input_vars = []
|
|
|
|
|
|
|
|
|
|
# self.all_prefetch_output_vars =
|
|
|
|
|
# [[var0_prefetch_out_pserver0, var0_prefetch_out_pserver1]
|
|
|
|
|
# [var1_prefetch_out_pserver0, var1_prefetch_out_pserver1]]
|
|
|
|
|
self.all_prefetch_output_vars = []
|
|
|
|
|
|
|
|
|
|
continue_search_lookup_table_op = True
|
|
|
|
|
while continue_search_lookup_table_op:
|
|
|
|
@ -623,18 +629,19 @@ class DistributeTranspiler:
|
|
|
|
|
ids_name = op.input("Ids")
|
|
|
|
|
out_name = op.output("Out")
|
|
|
|
|
|
|
|
|
|
if self.prefetch_input_vars is None:
|
|
|
|
|
ids_var = program.global_block().vars[ids_name[0]]
|
|
|
|
|
self.prefetch_input_vars = self.create_splited_vars(
|
|
|
|
|
source_var=ids_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_in_")
|
|
|
|
|
if self.prefetch_output_vars is None:
|
|
|
|
|
out_var = program.global_block().vars[out_name[0]]
|
|
|
|
|
self.prefetch_output_vars = self.create_splited_vars(
|
|
|
|
|
source_var=out_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_out_")
|
|
|
|
|
ids_var = program.global_block().vars[ids_name[0]]
|
|
|
|
|
prefetch_input_vars = self.create_splited_vars(
|
|
|
|
|
source_var=ids_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_in_")
|
|
|
|
|
self.all_prefetch_input_vars.append(prefetch_input_vars)
|
|
|
|
|
|
|
|
|
|
out_var = program.global_block().vars[out_name[0]]
|
|
|
|
|
prefetch_output_vars = self.create_splited_vars(
|
|
|
|
|
source_var=out_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_out_")
|
|
|
|
|
self.all_prefetch_output_vars.append(prefetch_output_vars)
|
|
|
|
|
|
|
|
|
|
# insert split_ids_op
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
@ -646,14 +653,14 @@ class DistributeTranspiler:
|
|
|
|
|
for varname in ids_name
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
outputs={"Out": self.prefetch_input_vars})
|
|
|
|
|
outputs={"Out": prefetch_input_vars})
|
|
|
|
|
|
|
|
|
|
# insert prefetch_op
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index + 1,
|
|
|
|
|
type="prefetch",
|
|
|
|
|
inputs={'X': self.prefetch_input_vars},
|
|
|
|
|
outputs={"Out": self.prefetch_output_vars},
|
|
|
|
|
inputs={'X': prefetch_input_vars},
|
|
|
|
|
outputs={"Out": prefetch_output_vars},
|
|
|
|
|
attrs={
|
|
|
|
|
"epmap": pserver_endpoints,
|
|
|
|
|
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
@ -663,7 +670,7 @@ class DistributeTranspiler:
|
|
|
|
|
program.global_block().insert_op(
|
|
|
|
|
index=op_index + 2,
|
|
|
|
|
type="concat",
|
|
|
|
|
inputs={'X': self.prefetch_output_vars},
|
|
|
|
|
inputs={'X': prefetch_output_vars},
|
|
|
|
|
outputs={
|
|
|
|
|
"Out": [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
@ -709,30 +716,34 @@ class DistributeTranspiler:
|
|
|
|
|
optimize_block):
|
|
|
|
|
# STEP: create prefetch block
|
|
|
|
|
table_var = pserver_program.global_block().vars[self.table_name]
|
|
|
|
|
prefetch_block = pserver_program.create_block(optimize_block.idx)
|
|
|
|
|
trainer_ids = self.prefetch_input_vars[pserver_index]
|
|
|
|
|
pserver_ids = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_ids.name,
|
|
|
|
|
type=trainer_ids.type,
|
|
|
|
|
shape=trainer_ids.shape,
|
|
|
|
|
dtype=trainer_ids.dtype)
|
|
|
|
|
trainer_out = self.prefetch_output_vars[pserver_index]
|
|
|
|
|
pserver_out = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_out.name,
|
|
|
|
|
type=trainer_out.type,
|
|
|
|
|
shape=trainer_out.shape,
|
|
|
|
|
dtype=trainer_out.dtype)
|
|
|
|
|
prefetch_block.append_op(
|
|
|
|
|
type="lookup_sparse_table",
|
|
|
|
|
inputs={'Ids': pserver_ids,
|
|
|
|
|
"W": table_var},
|
|
|
|
|
outputs={"Out": pserver_out},
|
|
|
|
|
attrs={
|
|
|
|
|
"is_sparse": True, # has no effect on lookup_table op
|
|
|
|
|
"is_distributed": True,
|
|
|
|
|
"padding_idx": -1
|
|
|
|
|
})
|
|
|
|
|
return prefetch_block
|
|
|
|
|
prefetch_var_name_to_block_id = []
|
|
|
|
|
for index in range(len(self.all_prefetch_input_vars)):
|
|
|
|
|
prefetch_block = pserver_program.create_block(optimize_block.idx)
|
|
|
|
|
trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
|
|
|
|
|
pserver_ids = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_ids.name,
|
|
|
|
|
type=trainer_ids.type,
|
|
|
|
|
shape=trainer_ids.shape,
|
|
|
|
|
dtype=trainer_ids.dtype)
|
|
|
|
|
trainer_out = self.all_prefetch_output_vars[index][pserver_index]
|
|
|
|
|
pserver_out = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_out.name,
|
|
|
|
|
type=trainer_out.type,
|
|
|
|
|
shape=trainer_out.shape,
|
|
|
|
|
dtype=trainer_out.dtype)
|
|
|
|
|
prefetch_block.append_op(
|
|
|
|
|
type="lookup_sparse_table",
|
|
|
|
|
inputs={'Ids': pserver_ids,
|
|
|
|
|
"W": table_var},
|
|
|
|
|
outputs={"Out": pserver_out},
|
|
|
|
|
attrs={
|
|
|
|
|
"is_sparse": True, # has no effect on lookup_table op
|
|
|
|
|
"is_distributed": True,
|
|
|
|
|
"padding_idx": -1
|
|
|
|
|
})
|
|
|
|
|
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
|
|
|
|
|
prefetch_block.idx))
|
|
|
|
|
return prefetch_var_name_to_block_id
|
|
|
|
|
|
|
|
|
|
def _create_table_optimize_block(self, pserver_index, pserver_program,
|
|
|
|
|
pre_block_idx, grad_to_block_id):
|
|
|
|
|