|
|
|
@ -317,8 +317,7 @@ class DistributeTranspiler:
|
|
|
|
|
|
|
|
|
|
def get_trainer_program(self):
|
|
|
|
|
# remove optimize ops and add a send op to main_program
|
|
|
|
|
self.origin_program.global_block().delete_ops(self.optimize_ops)
|
|
|
|
|
self.origin_program.sync_with_cpp()
|
|
|
|
|
self.delete_ops(self.origin_program.global_block(), self.optimize_ops)
|
|
|
|
|
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
|
|
|
|
|
self.origin_program.__str__()
|
|
|
|
|
return self.origin_program
|
|
|
|
@ -602,8 +601,7 @@ class DistributeTranspiler:
|
|
|
|
|
attrs={"axis": 0})
|
|
|
|
|
|
|
|
|
|
# delete lookup_table_op
|
|
|
|
|
program.global_block().delete_ops([op])
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
self.delete_ops(program.global_block(), [op])
|
|
|
|
|
# break for loop
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
@ -1166,3 +1164,12 @@ class DistributeTranspiler:
|
|
|
|
|
in_name.startswith("beta2_pow_acc"):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def delete_ops(self, block, ops):
|
|
|
|
|
try:
|
|
|
|
|
start = list(block.ops).index(ops[0])
|
|
|
|
|
end = list(block.ops).index(ops[-1])
|
|
|
|
|
[block.remove_op(start) for _ in xrange(end - start + 1)]
|
|
|
|
|
except Exception, e:
|
|
|
|
|
raise e
|
|
|
|
|
block.program.sync_with_cpp()
|
|
|
|
|