fix memory opt skip set by name (#14774)

* random failed. rerun ci. test=develop

* windows failed. rerun ci. test=develop
ce_debug
dzhwinter 7 years ago committed by chengduo
parent c4c5f0b8ca
commit 00776b167a

@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard
from paddle.fluid.transpiler import memory_optimize
def _get_vars(prog):
assert (isinstance(prog, Program))
all_vars = set()
for op in prog.global_block().ops:
all_vars.update(op.input_arg_names)
all_vars.update(op.output_arg_names)
return all_vars
class TestControlFlowGraph(unittest.TestCase):
def setUp(self):
program = Program()
@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase):
self.program = program
def test_control_flow_graph(self):
print("before optimization")
print(str(self.program))
result_program = memory_optimize(self.program)
print("after optimization")
print(str(result_program))
result_program = self.program.clone()
memory_optimize(self.program)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler2(unittest.TestCase):
@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase):
avg_cost = layers.mean(cost)
opt = optimizer.SGD(learning_rate=0.001)
opt.minimize(avg_cost)
self.skip_set = set([cost.name, fc.name])
self.program = program
def test_inplace_ops(self):
print("before optimization")
print(str(self.program))
result_program = memory_optimize(self.program)
print("after optimization")
print(str(result_program))
result_program = self.program.clone()
memory_optimize(self.program)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
def test_skip_opt(self):
result_program = self.program.clone()
memory_optimize(self.program, skip_opt_set=self.skip_set)
old_vars = _get_vars(self.program)
new_vars = _get_vars(result_program)
self.assertTrue(old_vars != new_vars)
class TestMemoryTranspiler3(unittest.TestCase):

@ -14,6 +14,7 @@
from __future__ import print_function
import six
from collections import defaultdict, MutableSet
from .. import core
from ... import compat as cpt
@ -470,8 +471,21 @@ def memory_optimize(input_program,
Returns:
None
"""
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
if level != 0 and level != 1:
raise ValueError("only support opt_level 0 or 1.")
if skip_opt_set is not None and not isinstance(skip_opt_set, set):
raise ValueError("only support skip_opt_set as set.")
global PRINT_LOG
PRINT_LOG = print_log
if skip_grads:
@ -486,6 +500,8 @@ def memory_optimize(input_program,
skip_opt_set = grad_set
else:
skip_opt_set.update(grad_set)
if skip_opt_set is not None:
skip_opt_set = set(map(to_name_str, skip_opt_set))
cfgs = _get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)

Loading…
Cancel
Save