@@ -302,7 +302,6 @@ class DistributeTranspiler(object):
         """
         # remove optimize ops and add a send op to main_program
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
-        # FIXME(typhoonzero): serialize once will fix error occurs when clone.
         self.origin_program.__str__()
         return self.origin_program
 
@@ -383,11 +382,12 @@ class DistributeTranspiler(object):
             if self._is_adam_connected_op(op):
                 global_ops.append(op)
 
-        def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
+        def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
+                                   lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
-            else:
+            elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)
 
         def __op_have_grad_input__(op):
@@ -452,7 +452,7 @@ class DistributeTranspiler(object):
                 # optimizer is connected to itself
                 if ufind.is_connected(op, opt_op) and op not in global_ops:
                     __append_optimize_op__(op, per_opt_block, grad_to_block_id,
-                                           merged_var)
+                                           merged_var, lr_ops)
 
         # append global ops
         if global_ops:
@@ -461,7 +461,7 @@ class DistributeTranspiler(object):
             optimize_blocks.append(opt_state_block)
             for glb_op in global_ops:
                 __append_optimize_op__(glb_op, opt_state_block,
-                                       grad_to_block_id, None)
+                                       grad_to_block_id, None, lr_ops)
 
         # process distributed lookup_table
         prefetch_var_name_to_block_id = []
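
Note (not part of the patch): a minimal standalone sketch of the dispatch rule this diff introduces. As I read it, lr decay ops are already emitted into their own learning-rate sub-block on the pserver (the lr_ops list comes from self._get_lr_ops() elsewhere in this file), so __append_optimize_op__ must now skip them on the non-optimizer path or they would be appended twice. All names below are invented for illustration; this is a toy model, not PaddlePaddle code.

    # Hypothetical toy model of the new `elif op not in lr_ops` dispatch.
    def append_op(op, block, optimizer_ops, lr_ops):
        if op in optimizer_ops:
            block.append(("opt", op))   # optimizer update path
        elif op not in lr_ops:
            block.append(("copy", op))  # plain op copied verbatim
        # ops in lr_ops are skipped: they already live in the lr-decay block

    block, optimizer_ops, lr_ops = [], {"sgd"}, {"lr_decay"}
    for op in ["scale", "lr_decay", "sgd"]:
        append_op(op, block, optimizer_ops, lr_ops)
    assert block == [("copy", "scale"), ("opt", "sgd")]  # lr_decay not duplicated

Under the old `else:` branch, "lr_decay" would have been copied into the block a second time, which is the duplication this change removes.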