|
|
@ -25,6 +25,8 @@ LOOKUP_TABLE_TYPE = "lookup_table"
|
|
|
|
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
|
|
|
|
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
|
|
|
|
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
GLOBAL_BLOCK_IDX = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VarBlock:
|
|
|
|
class VarBlock:
|
|
|
|
def __init__(self, varname, offset, size):
|
|
|
|
def __init__(self, varname, offset, size):
|
|
|
@ -368,8 +370,8 @@ class DistributeTranspiler:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
recv_inputs.append(single_trainer_var)
|
|
|
|
recv_inputs.append(single_trainer_var)
|
|
|
|
|
|
|
|
|
|
|
|
# step3
|
|
|
|
optimize_block = None
|
|
|
|
optimize_block = pserver_program.create_block(0)
|
|
|
|
|
|
|
|
# step 4
|
|
|
|
# step 4
|
|
|
|
# Create a union-find data structure from optimize ops,
|
|
|
|
# Create a union-find data structure from optimize ops,
|
|
|
|
# If two ops are connected, we could add these two ops
|
|
|
|
# If two ops are connected, we could add these two ops
|
|
|
@ -415,29 +417,34 @@ class DistributeTranspiler:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self._append_pserver_non_opt_ops(block, op)
|
|
|
|
self._append_pserver_non_opt_ops(block, op)
|
|
|
|
|
|
|
|
|
|
|
|
append_block = optimize_block
|
|
|
|
|
|
|
|
# append lr decay ops to the child block if exists
|
|
|
|
# append lr decay ops to the child block if exists
|
|
|
|
|
|
|
|
lr_decay_block = None
|
|
|
|
lr_ops = self._get_lr_ops()
|
|
|
|
lr_ops = self._get_lr_ops()
|
|
|
|
if len(lr_ops) > 0:
|
|
|
|
if len(lr_ops) > 0:
|
|
|
|
|
|
|
|
lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX)
|
|
|
|
for _, op in enumerate(lr_ops):
|
|
|
|
for _, op in enumerate(lr_ops):
|
|
|
|
self._append_pserver_non_opt_ops(append_block, op)
|
|
|
|
self._append_pserver_non_opt_ops(lr_decay_block, op)
|
|
|
|
|
|
|
|
|
|
|
|
append_block = pserver_program.create_block(append_block.idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# append op to the current block
|
|
|
|
# append op to the current block
|
|
|
|
per_opt_block = append_block
|
|
|
|
per_opt_block = None
|
|
|
|
|
|
|
|
pre_block_idx = GLOBAL_BLOCK_IDX
|
|
|
|
|
|
|
|
if lr_decay_block is not None:
|
|
|
|
|
|
|
|
pre_block_idx = lr_decay_block.idx
|
|
|
|
for idx, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
for idx, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
|
|
|
|
per_opt_block = pserver_program.create_block(pre_block_idx)
|
|
|
|
|
|
|
|
if optimize_block is None:
|
|
|
|
|
|
|
|
optimize_block = per_opt_block
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
# optimizer is connected to itself
|
|
|
|
# optimizer is connected to itself
|
|
|
|
if ufind.is_connected(op, opt_op) and \
|
|
|
|
if ufind.is_connected(op, opt_op) and op not in global_ops:
|
|
|
|
op not in global_ops:
|
|
|
|
|
|
|
|
__append_optimize_op__(op, per_opt_block)
|
|
|
|
__append_optimize_op__(op, per_opt_block)
|
|
|
|
if idx == len(opt_op_on_pserver) - 1 and global_ops:
|
|
|
|
|
|
|
|
per_opt_block = pserver_program.create_block(append_block.idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# append global ops
|
|
|
|
# append global ops
|
|
|
|
|
|
|
|
opt_state_block = None
|
|
|
|
|
|
|
|
if global_ops:
|
|
|
|
|
|
|
|
opt_state_block = pserver_program.create_block(per_opt_block.idx)
|
|
|
|
for glb_op in global_ops:
|
|
|
|
for glb_op in global_ops:
|
|
|
|
__append_optimize_op__(glb_op, per_opt_block)
|
|
|
|
__append_optimize_op__(glb_op, opt_state_block)
|
|
|
|
|
|
|
|
|
|
|
|
# NOT USED: single block version:
|
|
|
|
# NOT USED: single block version:
|
|
|
|
#
|
|
|
|
#
|
|
|
@ -451,10 +458,11 @@ class DistributeTranspiler:
|
|
|
|
prefetch_block = None
|
|
|
|
prefetch_block = None
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
if self.has_distributed_lookup_table:
|
|
|
|
pserver_index = self.pserver_endpoints.index(endpoint)
|
|
|
|
pserver_index = self.pserver_endpoints.index(endpoint)
|
|
|
|
self._create_table_optimize_block(pserver_index, pserver_program,
|
|
|
|
table_opt_block = self._create_table_optimize_block(
|
|
|
|
append_block)
|
|
|
|
pserver_index, pserver_program, opt_state_block or
|
|
|
|
|
|
|
|
pserver_program.global_block())
|
|
|
|
prefetch_block = self._create_prefetch_block(
|
|
|
|
prefetch_block = self._create_prefetch_block(
|
|
|
|
pserver_index, pserver_program, optimize_block)
|
|
|
|
pserver_index, pserver_program, table_opt_block)
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
|
|
|
|
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
|
|
|
|
# not be executed, so it's safe to use optimize_block to hold the place
|
|
|
|
# not be executed, so it's safe to use optimize_block to hold the place
|
|
|
@ -724,6 +732,8 @@ class DistributeTranspiler:
|
|
|
|
outputs=outputs,
|
|
|
|
outputs=outputs,
|
|
|
|
attrs=table_opt_op.attrs)
|
|
|
|
attrs=table_opt_op.attrs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return table_opt_block
|
|
|
|
|
|
|
|
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
def _create_vars_from_blocklist(self,
|
|
|
|
def _create_vars_from_blocklist(self,
|
|
|
|
program,
|
|
|
|
program,
|
|
|
|