|
|
|
@ -713,7 +713,7 @@ in a single call.")
|
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
# optimizer is connected to itself
|
|
|
|
|
if op.attr(OP_ROLE_VAR_ATTR_NAME)[0] == optimize_target_param_name and \
|
|
|
|
|
op not in global_ops:
|
|
|
|
|
op not in global_ops:
|
|
|
|
|
log("append opt op: ", op.type, op.input_arg_names,
|
|
|
|
|
merged_var)
|
|
|
|
|
__append_optimize_op__(op, per_opt_block,
|
|
|
|
@ -1034,15 +1034,11 @@ to transpile() call.")
|
|
|
|
|
def _replace_lookup_table_op_with_prefetch(self, program,
|
|
|
|
|
pserver_endpoints):
|
|
|
|
|
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
|
|
|
|
|
# self.all_prefetch_input_vars =
|
|
|
|
|
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
|
|
|
|
|
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
|
|
|
|
|
self.all_in_ids_vars = []
|
|
|
|
|
self.all_prefetch_input_vars = []
|
|
|
|
|
|
|
|
|
|
# self.all_prefetch_input_vars =
|
|
|
|
|
# [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
|
|
|
|
|
# [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
|
|
|
|
|
self.all_prefetch_output_vars = []
|
|
|
|
|
self.all_out_emb_vars = []
|
|
|
|
|
lookup_table_op_index = -1
|
|
|
|
|
|
|
|
|
|
continue_search_lookup_table_op = True
|
|
|
|
|
while continue_search_lookup_table_op:
|
|
|
|
@ -1052,72 +1048,68 @@ to transpile() call.")
|
|
|
|
|
if op.type == LOOKUP_TABLE_TYPE:
|
|
|
|
|
continue_search_lookup_table_op = True
|
|
|
|
|
|
|
|
|
|
lookup_table_op_index = list(all_ops).index(op)
|
|
|
|
|
lookup_table_op_index = lookup_table_op_index if lookup_table_op_index != -1 else list(
|
|
|
|
|
all_ops).index(op)
|
|
|
|
|
ids_name = op.input("Ids")
|
|
|
|
|
out_name = op.output("Out")
|
|
|
|
|
|
|
|
|
|
ids_var = program.global_block().vars[ids_name[0]]
|
|
|
|
|
prefetch_input_vars = self._create_splited_vars(
|
|
|
|
|
source_var=ids_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_in_")
|
|
|
|
|
self.all_prefetch_input_vars.append(prefetch_input_vars)
|
|
|
|
|
self.all_in_ids_vars.append(ids_var)
|
|
|
|
|
|
|
|
|
|
out_var = program.global_block().vars[out_name[0]]
|
|
|
|
|
prefetch_output_vars = self._create_splited_vars(
|
|
|
|
|
source_var=out_var,
|
|
|
|
|
block=program.global_block(),
|
|
|
|
|
tag="_prefetch_out_")
|
|
|
|
|
self.all_prefetch_output_vars.append(prefetch_output_vars)
|
|
|
|
|
|
|
|
|
|
# insert split_ids_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index,
|
|
|
|
|
type="split_ids",
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
|
for varname in ids_name
|
|
|
|
|
]
|
|
|
|
|
},
|
|
|
|
|
outputs={"Out": prefetch_input_vars})
|
|
|
|
|
|
|
|
|
|
# insert prefetch_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index + 1,
|
|
|
|
|
type="prefetch",
|
|
|
|
|
inputs={'X': prefetch_input_vars},
|
|
|
|
|
outputs={"Out": prefetch_output_vars},
|
|
|
|
|
attrs={
|
|
|
|
|
"epmap": pserver_endpoints,
|
|
|
|
|
# FIXME(qiao) temporarily disable this config because prefetch
|
|
|
|
|
# is not act as other rpc op, it's more like a forward op
|
|
|
|
|
# RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# insert concat_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index + 2,
|
|
|
|
|
type="merge_ids",
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
|
for varname in ids_name
|
|
|
|
|
],
|
|
|
|
|
'X': prefetch_output_vars
|
|
|
|
|
},
|
|
|
|
|
outputs={
|
|
|
|
|
"Out": [
|
|
|
|
|
program.global_block().vars[varname]
|
|
|
|
|
for varname in out_name
|
|
|
|
|
]
|
|
|
|
|
})
|
|
|
|
|
self.all_out_emb_vars.append(out_var)
|
|
|
|
|
|
|
|
|
|
# delete lookup_table_op
|
|
|
|
|
delete_ops(program.global_block(), [op])
|
|
|
|
|
# break for loop
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
for index in range(len(self.pserver_endpoints)):
|
|
|
|
|
in_var = program.global_block().create_var(
|
|
|
|
|
name=str("prefetch_compress_in_tmp_" + str(index)),
|
|
|
|
|
type=self.all_in_ids_vars[0].type,
|
|
|
|
|
shape=self.all_in_ids_vars[0].shape,
|
|
|
|
|
dtype=self.all_in_ids_vars[0].dtype)
|
|
|
|
|
self.all_prefetch_input_vars.append(in_var)
|
|
|
|
|
|
|
|
|
|
out_var = program.global_block().create_var(
|
|
|
|
|
name=str("prefetch_compress_out_tmp_" + str(index)),
|
|
|
|
|
type=self.all_out_emb_vars[0].type,
|
|
|
|
|
shape=self.all_out_emb_vars[0].shape,
|
|
|
|
|
dtype=self.all_out_emb_vars[0].dtype)
|
|
|
|
|
self.all_prefetch_output_vars.append(out_var)
|
|
|
|
|
|
|
|
|
|
# insert split_ids_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index,
|
|
|
|
|
type="split_ids",
|
|
|
|
|
inputs={'Ids': self.all_in_ids_vars},
|
|
|
|
|
outputs={"Out": self.all_prefetch_input_vars})
|
|
|
|
|
|
|
|
|
|
# insert prefetch_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index + 1,
|
|
|
|
|
type="prefetch",
|
|
|
|
|
inputs={'X': self.all_prefetch_input_vars},
|
|
|
|
|
outputs={"Out": self.all_prefetch_output_vars},
|
|
|
|
|
attrs={
|
|
|
|
|
"epmap": pserver_endpoints,
|
|
|
|
|
# FIXME(qiao) temporarily disable this config because prefetch
|
|
|
|
|
# is not act as other rpc op, it's more like a forward op
|
|
|
|
|
# RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
# insert concat_op
|
|
|
|
|
program.global_block()._insert_op(
|
|
|
|
|
index=lookup_table_op_index + 2,
|
|
|
|
|
type="merge_ids",
|
|
|
|
|
inputs={
|
|
|
|
|
'Ids': self.all_in_ids_vars,
|
|
|
|
|
'Rows': self.all_prefetch_input_vars,
|
|
|
|
|
'X': self.all_prefetch_output_vars
|
|
|
|
|
},
|
|
|
|
|
outputs={"Out": self.all_out_emb_vars})
|
|
|
|
|
|
|
|
|
|
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
|
|
|
|
|
# 2. add split_ids_op and send_op to send gradient to pservers
|
|
|
|
|
|
|
|
|
@ -1160,32 +1152,31 @@ to transpile() call.")
|
|
|
|
|
# STEP: create prefetch block
|
|
|
|
|
table_var = pserver_program.global_block().vars[self.table_name]
|
|
|
|
|
prefetch_var_name_to_block_id = []
|
|
|
|
|
for index in range(len(self.all_prefetch_input_vars)):
|
|
|
|
|
prefetch_block = pserver_program._create_block(optimize_block.idx)
|
|
|
|
|
trainer_ids = self.all_prefetch_input_vars[index][pserver_index]
|
|
|
|
|
pserver_ids = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_ids.name,
|
|
|
|
|
type=trainer_ids.type,
|
|
|
|
|
shape=trainer_ids.shape,
|
|
|
|
|
dtype=trainer_ids.dtype)
|
|
|
|
|
trainer_out = self.all_prefetch_output_vars[index][pserver_index]
|
|
|
|
|
pserver_out = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_out.name,
|
|
|
|
|
type=trainer_out.type,
|
|
|
|
|
shape=trainer_out.shape,
|
|
|
|
|
dtype=trainer_out.dtype)
|
|
|
|
|
prefetch_block.append_op(
|
|
|
|
|
type="lookup_sparse_table",
|
|
|
|
|
inputs={'Ids': pserver_ids,
|
|
|
|
|
"W": table_var},
|
|
|
|
|
outputs={"Out": pserver_out},
|
|
|
|
|
attrs={
|
|
|
|
|
"is_sparse": True, # has no effect on lookup_table op
|
|
|
|
|
"is_distributed": True,
|
|
|
|
|
"padding_idx": -1
|
|
|
|
|
})
|
|
|
|
|
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
|
|
|
|
|
prefetch_block.idx))
|
|
|
|
|
prefetch_block = pserver_program._create_block(optimize_block.idx)
|
|
|
|
|
trainer_ids = self.all_prefetch_input_vars[pserver_index]
|
|
|
|
|
pserver_ids = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_ids.name,
|
|
|
|
|
type=trainer_ids.type,
|
|
|
|
|
shape=trainer_ids.shape,
|
|
|
|
|
dtype=trainer_ids.dtype)
|
|
|
|
|
trainer_out = self.all_prefetch_output_vars[pserver_index]
|
|
|
|
|
pserver_out = pserver_program.global_block().create_var(
|
|
|
|
|
name=trainer_out.name,
|
|
|
|
|
type=trainer_out.type,
|
|
|
|
|
shape=trainer_out.shape,
|
|
|
|
|
dtype=trainer_out.dtype)
|
|
|
|
|
prefetch_block.append_op(
|
|
|
|
|
type="lookup_sparse_table",
|
|
|
|
|
inputs={'Ids': pserver_ids,
|
|
|
|
|
"W": table_var},
|
|
|
|
|
outputs={"Out": pserver_out},
|
|
|
|
|
attrs={
|
|
|
|
|
"is_sparse": True, # has no effect on lookup_table op
|
|
|
|
|
"is_distributed": True,
|
|
|
|
|
"padding_idx": -1
|
|
|
|
|
})
|
|
|
|
|
prefetch_var_name_to_block_id.append(trainer_ids.name + ":" + str(
|
|
|
|
|
prefetch_block.idx))
|
|
|
|
|
return prefetch_var_name_to_block_id
|
|
|
|
|
|
|
|
|
|
def _create_table_optimize_block(self, pserver_index, pserver_program,
|
|
|
|
@ -1364,16 +1355,6 @@ to transpile() call.")
|
|
|
|
|
program.global_block()._sync_with_cpp()
|
|
|
|
|
return var_mapping
|
|
|
|
|
|
|
|
|
|
def _create_splited_vars(self, source_var, block, tag):
|
|
|
|
|
return [
|
|
|
|
|
block.create_var(
|
|
|
|
|
name=str(source_var.name + tag + str(index)),
|
|
|
|
|
type=source_var.type,
|
|
|
|
|
shape=source_var.shape,
|
|
|
|
|
dtype=source_var.dtype)
|
|
|
|
|
for index in range(len(self.pserver_endpoints))
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def _clone_var(self, block, var, persistable=True):
|
|
|
|
|
return block.create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|