|
|
|
@ -371,22 +371,18 @@ class DistributeTranspiler:
|
|
|
|
|
recv_inputs.append(single_trainer_var)
|
|
|
|
|
|
|
|
|
|
# step 3
|
|
|
|
|
# each optimization op will has a optimize block
|
|
|
|
|
optimize_block = None
|
|
|
|
|
|
|
|
|
|
# step 4
|
|
|
|
|
# Create a union-find data structure from optimize ops,
|
|
|
|
|
# If two ops are connected, we could add these two ops
|
|
|
|
|
# into one set.
|
|
|
|
|
ufind = self._create_ufind(self.optimize_ops)
|
|
|
|
|
# step 4.2
|
|
|
|
|
# step 3.2
|
|
|
|
|
# Iterate through the ops and append optimize op which
|
|
|
|
|
# located on current pserver
|
|
|
|
|
opt_op_on_pserver = []
|
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
|
|
|
|
|
opt_op_on_pserver.append(op)
|
|
|
|
|
# step 4.3
|
|
|
|
|
# step 3.3
|
|
|
|
|
# Iterate through the ops, and if an op and the optimize ops
|
|
|
|
|
# which located on current pserver are in one set, then
|
|
|
|
|
# append it into the sub program.
|
|
|
|
@ -420,23 +416,17 @@ class DistributeTranspiler:
|
|
|
|
|
self._append_pserver_non_opt_ops(block, op)
|
|
|
|
|
|
|
|
|
|
# append lr decay ops to the child block if exists
|
|
|
|
|
lr_decay_block = None
|
|
|
|
|
lr_ops = self._get_lr_ops()
|
|
|
|
|
if len(lr_ops) > 0:
|
|
|
|
|
lr_decay_block = pserver_program.create_block(GLOBAL_BLOCK_IDX)
|
|
|
|
|
lr_decay_block = pserver_program.create_block(
|
|
|
|
|
pserver_program.num_blocks - 1)
|
|
|
|
|
for _, op in enumerate(lr_ops):
|
|
|
|
|
self._append_pserver_non_opt_ops(lr_decay_block, op)
|
|
|
|
|
|
|
|
|
|
# append op to the current block
|
|
|
|
|
per_opt_block = None
|
|
|
|
|
pre_block_idx = GLOBAL_BLOCK_IDX
|
|
|
|
|
if lr_decay_block is not None:
|
|
|
|
|
pre_block_idx = lr_decay_block.idx
|
|
|
|
|
pre_block_idx = pserver_program.num_blocks - 1
|
|
|
|
|
for idx, opt_op in enumerate(opt_op_on_pserver):
|
|
|
|
|
per_opt_block = pserver_program.create_block(pre_block_idx)
|
|
|
|
|
if optimize_block is None:
|
|
|
|
|
# first optimize block
|
|
|
|
|
optimize_block = per_opt_block
|
|
|
|
|
for _, op in enumerate(self.optimize_ops):
|
|
|
|
|
# optimizer is connected to itself
|
|
|
|
|
if ufind.is_connected(op, opt_op) and op not in global_ops:
|
|
|
|
@ -445,9 +435,10 @@ class DistributeTranspiler:
|
|
|
|
|
# append global ops
|
|
|
|
|
opt_state_block = None
|
|
|
|
|
if global_ops:
|
|
|
|
|
opt_state_block = pserver_program.create_block(per_opt_block.idx)
|
|
|
|
|
for glb_op in global_ops:
|
|
|
|
|
__append_optimize_op__(glb_op, opt_state_block)
|
|
|
|
|
opt_state_block = pserver_program.create_block(
|
|
|
|
|
pserver_program.num_blocks - 1)
|
|
|
|
|
for glb_op in global_ops:
|
|
|
|
|
__append_optimize_op__(glb_op, opt_state_block)
|
|
|
|
|
|
|
|
|
|
# NOT USED: single block version:
|
|
|
|
|
#
|
|
|
|
@ -481,7 +472,7 @@ class DistributeTranspiler:
|
|
|
|
|
inputs={'X': recv_inputs},
|
|
|
|
|
outputs={},
|
|
|
|
|
attrs={
|
|
|
|
|
"OptimizeBlock": optimize_block,
|
|
|
|
|
"OptimizeBlock": pserver_program.block(1),
|
|
|
|
|
"endpoint": endpoint,
|
|
|
|
|
"Fanin": self.trainer_num,
|
|
|
|
|
"PrefetchBlock": prefetch_block
|
|
|
|
|