@@ -408,9 +408,9 @@ class DistributeTranspiler:
                             in_name.startswith("beta2_pow_acc"):
                         global_ops.append(op)
 
-        def __append_optimize_op__(op, block):
+        def __append_optimize_op__(op, block, grad_to_block_id):
             if self._is_opt_op(op):
-                self._append_pserver_ops(block, op, endpoint,
+                self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          default_main_program())
             else:
                 self._append_pserver_non_opt_ops(block, op)
@@ -424,13 +424,14 @@ class DistributeTranspiler:
                 self._append_pserver_non_opt_ops(lr_decay_block, op)
 
         # append op to the current block
+        grad_to_block_id = []
         pre_block_idx = pserver_program.num_blocks - 1
         for idx, opt_op in enumerate(opt_op_on_pserver):
             per_opt_block = pserver_program.create_block(pre_block_idx)
             for _, op in enumerate(self.optimize_ops):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
-                    __append_optimize_op__(op, per_opt_block)
+                    __append_optimize_op__(op, per_opt_block, grad_to_block_id)
 
         # append global ops
         opt_state_block = None
@@ -476,7 +477,7 @@ class DistributeTranspiler:
                 "Fanin": self.trainer_num,
                 "PrefetchBlock": prefetch_block,
                 "sync_mode": self.sync_mode,
-                "grad_to_id": []
+                "grad_to_id": grad_to_block_id
             })
 
         pserver_program.sync_with_cpp()
@@ -883,7 +884,7 @@ class DistributeTranspiler:
         return orig_var_name
 
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            origin_program):
+                            grad_to_block_id, origin_program):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = dict()
@@ -904,6 +905,8 @@ class DistributeTranspiler:
                     return
                 merged_var = \
                     pserver_block.vars[self._orig_varname(grad_block.name)]
+                grad_to_block_id.append(merged_var.name + ":" + str(
+                    optimize_block.idx))
                 if self.trainer_num > 1:
                     vars2merge = []
                     for i in xrange(self.trainer_num):
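
Note (not part of the patch): the grad_to_block_id list threaded through the hunks above collects strings of the form "<merged gradient variable name>:<optimize block index>", which are then passed to the server-side op through the "grad_to_id" attribute. A minimal sketch of how such an entry could be built and decoded, with hypothetical names and indices:

    # Illustrative only; the variable name and block index are assumptions.
    grad_to_block_id = []
    merged_var_name = "fc_0.w_0@GRAD"  # assumed merged gradient variable name
    optimize_block_idx = 2             # assumed index of the per-gradient optimize block
    grad_to_block_id.append(merged_var_name + ":" + str(optimize_block_idx))

    # One way a consumer of the attribute could recover the mapping:
    mapping = dict(entry.rsplit(":", 1) for entry in grad_to_block_id)
    assert mapping == {"fc_0.w_0@GRAD": "2"}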