|
|
|
@ -226,7 +226,19 @@ class Executor(object):
|
|
|
|
|
feed_var_name='feed',
|
|
|
|
|
fetch_var_name='fetch',
|
|
|
|
|
scope=None,
|
|
|
|
|
return_numpy=True):
|
|
|
|
|
return_numpy=True,
|
|
|
|
|
use_program_cache=False):
|
|
|
|
|
"""
|
|
|
|
|
:param program: the program that need to run
|
|
|
|
|
:param feed: feed variable list
|
|
|
|
|
:param fetch_list: fetch variable list
|
|
|
|
|
:param feed_var_name: feed_var_name default to 'feed'
|
|
|
|
|
:param fetch_var_name: fetch_var_name default to 'fetch'
|
|
|
|
|
:param scope: the scope used to run this program, you can switch it to different scope.
|
|
|
|
|
:param return_numpy: convert the fetched tensor to numpy
|
|
|
|
|
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if feed is None:
|
|
|
|
|
feed = {}
|
|
|
|
|
if fetch_list is None:
|
|
|
|
@ -244,7 +256,7 @@ class Executor(object):
|
|
|
|
|
program_cache_key = str(feed.keys() + fetch_list)
|
|
|
|
|
program_cache = self.program_caches.get(program_cache_key, None)
|
|
|
|
|
|
|
|
|
|
if program_cache is None:
|
|
|
|
|
if program_cache is None or not use_program_cache:
|
|
|
|
|
program_cache = program.clone()
|
|
|
|
|
self.program_caches[program_cache_key] = program_cache
|
|
|
|
|
|
|
|
|
@ -266,6 +278,7 @@ class Executor(object):
|
|
|
|
|
type=core.VarDesc.VarType.FETCH_LIST,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
# prepend feed operators
|
|
|
|
|
if not has_feed_operators(global_block, feed, feed_var_name):
|
|
|
|
|
for i, name in enumerate(feed):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
@ -275,6 +288,7 @@ class Executor(object):
|
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
# append fetch_operators
|
|
|
|
|
if not has_fetch_operators(global_block, fetch_list,
|
|
|
|
|
fetch_var_name):
|
|
|
|
|
for i, var in enumerate(fetch_list):
|
|
|
|
|