fix pruned_program_cache_key of Operator (#23594)

* fix init_gflags with 'python -c', test=develop

* fix pruned_program_cache_key of Operator, test=develop
revert-23830-2.0-beta
Leo Chen 6 years ago committed by GitHub
parent 2c4b57e94b
commit 02b4e989b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -354,7 +354,7 @@ def _to_name_str(var):
elif isinstance(var, six.string_types): elif isinstance(var, six.string_types):
return str(var) return str(var)
elif isinstance(var, Operator): elif isinstance(var, Operator):
return var.desc.type() return str(id(var))
else: else:
raise TypeError(str(var) + " should be Variable, Operator or str") raise TypeError(str(var) + " should be Variable, Operator or str")

@ -323,7 +323,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
def test_prune_with_cache_program(self): def test_prune_with_cache_program(self):
''' '''
When use_prune=True and use_program_cache=True, Executor should cache the pruned program. When use_prune=True, Executor should cache the pruned program.
If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program, If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program,
and needn't call _prune_program() to prune the program. and needn't call _prune_program() to prune the program.
In this test, we hack the Executor._prune_program with a mock function which do nothing but increase In this test, we hack the Executor._prune_program with a mock function which do nothing but increase
@ -350,16 +350,68 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
feed={'x': x_np, feed={'x': x_np,
'label': label_np}, 'label': label_np},
fetch_list=[loss1.name], fetch_list=[loss1.name],
use_prune=True, use_prune=True)
use_program_cache=True)
if i == 0: if i == 0:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
else: else:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
def test_prune_with_cache_program2(self):
    '''
    When use_prune=True, Executor should cache the pruned program.
    If the only difference in fetch_list is optimize_ops during multiple runs,
    the cache_keys should be different and get different pruned program.
    '''
    with _mock_guard(mock):
        executor = fluid.Executor(fluid.CPUPlace())
        executor.prune_called_times = 0
        main_program = framework.Program()
        startup_program = framework.Program()
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(main_program, startup_program):
                # Two-branch net with one independent Adam optimizer
                # per loss, so each branch has its own optimize ops.
                (x1, x2, y1, y2, label, loss1, loss2, w1_param_attrs,
                 w2_param_attrs) = self.net2()
                optimizer_a = fluid.optimizer.AdamOptimizer(
                    learning_rate=0.5)
                train1 = optimizer_a.minimize(loss1)
                optimizer_b = fluid.optimizer.AdamOptimizer(
                    learning_rate=0.5)
                train2 = optimizer_b.minimize(loss2)
                executor.run(startup_program)
                x_np = np.random.random(size=(10, 2)).astype('float32')
                label_np = np.random.randint(
                    1, size=(10, 1)).astype('int64')
                feed_dict = {'x1': x_np, 'x2': x_np, 'label': label_np}
                for step in range(10):
                    # Alternate the optimize op in fetch_list: odd steps
                    # fetch train1, even steps fetch train2.
                    train_op = train1 if step % 2 else train2
                    executor.run(main_program,
                                 feed=feed_dict,
                                 fetch_list=[loss1, loss2, train_op],
                                 use_prune=True)
                    # Two distinct fetch_lists should trigger exactly two
                    # prunes (one on each of the first two steps); every
                    # later step must hit the cached pruned program.
                    expected_times = 1 if step == 0 else 2
                    self.assertEqual(executor.prune_called_times,
                                     expected_times)
def test_prune_with_cache_compiled_program(self): def test_prune_with_cache_compiled_program(self):
''' '''
When use_prune=True and use_program_cache=True, Executor should cache the pruned program. When use_prune=True, Executor should cache the pruned program.
If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program, If in next run, the program, feed, fetch are not changed, Executor use the cached pruned program,
and needn't call _prune_program() to prune the program. and needn't call _prune_program() to prune the program.
In this test, we hack the Executor._prune_program with a mock function which do nothing but increase In this test, we hack the Executor._prune_program with a mock function which do nothing but increase
@ -389,8 +441,7 @@ class TestExecutorRunAutoPrune(unittest.TestCase):
feed={'x': x_np, feed={'x': x_np,
'label': label_np}, 'label': label_np},
fetch_list=[loss1.name], fetch_list=[loss1.name],
use_prune=True, use_prune=True)
use_program_cache=True)
if i == 0: if i == 0:
self.assertEqual(exe.prune_called_times, 1) self.assertEqual(exe.prune_called_times, 1)
else: else:

Loading…
Cancel
Save