|
|
|
@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard
|
|
|
|
|
from paddle.fluid.transpiler import memory_optimize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_vars(prog):
|
|
|
|
|
assert (isinstance(prog, Program))
|
|
|
|
|
all_vars = set()
|
|
|
|
|
for op in prog.global_block().ops:
|
|
|
|
|
all_vars.update(op.input_arg_names)
|
|
|
|
|
all_vars.update(op.output_arg_names)
|
|
|
|
|
return all_vars
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestControlFlowGraph(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
program = Program()
|
|
|
|
@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase):
|
|
|
|
|
self.program = program
|
|
|
|
|
|
|
|
|
|
def test_control_flow_graph(self):
|
|
|
|
|
print("before optimization")
|
|
|
|
|
print(str(self.program))
|
|
|
|
|
result_program = memory_optimize(self.program)
|
|
|
|
|
print("after optimization")
|
|
|
|
|
print(str(result_program))
|
|
|
|
|
result_program = self.program.clone()
|
|
|
|
|
memory_optimize(self.program)
|
|
|
|
|
old_vars = _get_vars(self.program)
|
|
|
|
|
new_vars = _get_vars(result_program)
|
|
|
|
|
self.assertTrue(old_vars != new_vars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMemoryTranspiler2(unittest.TestCase):
|
|
|
|
@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase):
|
|
|
|
|
avg_cost = layers.mean(cost)
|
|
|
|
|
opt = optimizer.SGD(learning_rate=0.001)
|
|
|
|
|
opt.minimize(avg_cost)
|
|
|
|
|
self.skip_set = set([cost.name, fc.name])
|
|
|
|
|
self.program = program
|
|
|
|
|
|
|
|
|
|
def test_inplace_ops(self):
|
|
|
|
|
print("before optimization")
|
|
|
|
|
print(str(self.program))
|
|
|
|
|
result_program = memory_optimize(self.program)
|
|
|
|
|
print("after optimization")
|
|
|
|
|
print(str(result_program))
|
|
|
|
|
result_program = self.program.clone()
|
|
|
|
|
memory_optimize(self.program)
|
|
|
|
|
old_vars = _get_vars(self.program)
|
|
|
|
|
new_vars = _get_vars(result_program)
|
|
|
|
|
self.assertTrue(old_vars != new_vars)
|
|
|
|
|
|
|
|
|
|
def test_skip_opt(self):
|
|
|
|
|
result_program = self.program.clone()
|
|
|
|
|
memory_optimize(self.program, skip_opt_set=self.skip_set)
|
|
|
|
|
old_vars = _get_vars(self.program)
|
|
|
|
|
new_vars = _get_vars(result_program)
|
|
|
|
|
self.assertTrue(old_vars != new_vars)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestMemoryTranspiler3(unittest.TestCase):
|
|
|
|
|