|
|
|
@ -235,6 +235,77 @@ class Executor(object):
|
|
|
|
|
tensor.set_lod(lod)
|
|
|
|
|
return tensor
|
|
|
|
|
|
|
|
|
|
def _get_program_cache(self, program_cache_key):
|
|
|
|
|
return self.program_caches.get(program_cache_key, None)
|
|
|
|
|
|
|
|
|
|
def _add_program_cache(self, program_cache_key, program):
|
|
|
|
|
self.program_caches[program_cache_key] = program
|
|
|
|
|
|
|
|
|
|
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
|
|
|
|
|
fetch_var_name):
|
|
|
|
|
tmp_program = program.clone()
|
|
|
|
|
|
|
|
|
|
global_block = tmp_program.global_block()
|
|
|
|
|
|
|
|
|
|
if feed_var_name in global_block.vars:
|
|
|
|
|
feed_var = global_block.var(feed_var_name)
|
|
|
|
|
else:
|
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
|
name=feed_var_name,
|
|
|
|
|
type=core.VarDesc.VarType.FEED_MINIBATCH,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
if fetch_var_name in global_block.vars:
|
|
|
|
|
fetch_var = global_block.var(fetch_var_name)
|
|
|
|
|
else:
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
|
name=fetch_var_name,
|
|
|
|
|
type=core.VarDesc.VarType.FETCH_LIST,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
# prepend feed operators
|
|
|
|
|
if not has_feed_operators(global_block, feed, feed_var_name):
|
|
|
|
|
for i, name in enumerate(feed):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
|
global_block.prepend_op(
|
|
|
|
|
type='feed',
|
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
# append fetch_operators
|
|
|
|
|
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
|
|
|
|
|
for i, var in enumerate(fetch_list):
|
|
|
|
|
assert isinstance(var, Variable) or isinstance(var, str), (
|
|
|
|
|
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
|
|
|
|
|
global_block.append_op(
|
|
|
|
|
type='fetch',
|
|
|
|
|
inputs={'X': [var]},
|
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
return tmp_program
|
|
|
|
|
|
|
|
|
|
def _feed_data(self, program, feed, feed_var_name, scope):
|
|
|
|
|
# feed var to framework
|
|
|
|
|
for op in program.global_block().ops:
|
|
|
|
|
if op.desc.type() == 'feed':
|
|
|
|
|
feed_target_name = op.desc.output('Out')[0]
|
|
|
|
|
cur_feed = feed[feed_target_name]
|
|
|
|
|
if not isinstance(cur_feed, core.LoDTensor):
|
|
|
|
|
cur_feed = self.aslodtensor(cur_feed)
|
|
|
|
|
idx = op.desc.attr('col')
|
|
|
|
|
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
|
|
|
|
|
else:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
def _fetch_data(self, fetch_list, fetch_var_name, scope):
|
|
|
|
|
outs = [
|
|
|
|
|
core.get_fetch_variable(scope, fetch_var_name, i)
|
|
|
|
|
for i in xrange(len(fetch_list))
|
|
|
|
|
]
|
|
|
|
|
return outs
|
|
|
|
|
|
|
|
|
|
def run(self,
|
|
|
|
|
program=None,
|
|
|
|
|
feed=None,
|
|
|
|
@ -268,7 +339,6 @@ class Executor(object):
|
|
|
|
|
raise TypeError("feed should be a map")
|
|
|
|
|
if fetch_list is None:
|
|
|
|
|
fetch_list = []
|
|
|
|
|
|
|
|
|
|
if program is None:
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
|
|
|
|
@ -278,79 +348,30 @@ class Executor(object):
|
|
|
|
|
if scope is None:
|
|
|
|
|
scope = global_scope()
|
|
|
|
|
|
|
|
|
|
program_cache = None
|
|
|
|
|
program_cache_key = get_program_cache_key(feed, fetch_list)
|
|
|
|
|
|
|
|
|
|
cache_key = get_program_cache_key(feed, fetch_list)
|
|
|
|
|
if use_program_cache:
|
|
|
|
|
# find program cache by cache_key
|
|
|
|
|
program_cache = self.program_caches.get(program_cache_key, None)
|
|
|
|
|
# TODO(qiao): Should check program_cache and program are exactly the same.
|
|
|
|
|
cached_program = self._get_program_cache(cache_key)
|
|
|
|
|
if cached_program is None:
|
|
|
|
|
cached_program = self._add_feed_fetch_ops(
|
|
|
|
|
program=program,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
feed_var_name=feed_var_name,
|
|
|
|
|
fetch_var_name=fetch_var_name)
|
|
|
|
|
self._add_program_cache(cache_key, cached_program)
|
|
|
|
|
program = cached_program
|
|
|
|
|
else:
|
|
|
|
|
self.program_caches.pop(program_cache_key, None)
|
|
|
|
|
|
|
|
|
|
if program_cache is None:
|
|
|
|
|
program_cache = program.clone()
|
|
|
|
|
|
|
|
|
|
if use_program_cache:
|
|
|
|
|
self.program_caches[program_cache_key] = program_cache
|
|
|
|
|
|
|
|
|
|
global_block = program_cache.global_block()
|
|
|
|
|
|
|
|
|
|
if feed_var_name in global_block.vars:
|
|
|
|
|
feed_var = global_block.var(feed_var_name)
|
|
|
|
|
else:
|
|
|
|
|
feed_var = global_block.create_var(
|
|
|
|
|
name=feed_var_name,
|
|
|
|
|
type=core.VarDesc.VarType.FEED_MINIBATCH,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
if fetch_var_name in global_block.vars:
|
|
|
|
|
fetch_var = global_block.var(fetch_var_name)
|
|
|
|
|
else:
|
|
|
|
|
fetch_var = global_block.create_var(
|
|
|
|
|
name=fetch_var_name,
|
|
|
|
|
type=core.VarDesc.VarType.FETCH_LIST,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
# prepend feed operators
|
|
|
|
|
if not has_feed_operators(global_block, feed, feed_var_name):
|
|
|
|
|
for i, name in enumerate(feed):
|
|
|
|
|
out = global_block.var(name)
|
|
|
|
|
global_block.prepend_op(
|
|
|
|
|
type='feed',
|
|
|
|
|
inputs={'X': [feed_var]},
|
|
|
|
|
outputs={'Out': [out]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
# append fetch_operators
|
|
|
|
|
if not has_fetch_operators(global_block, fetch_list,
|
|
|
|
|
fetch_var_name):
|
|
|
|
|
for i, var in enumerate(fetch_list):
|
|
|
|
|
assert isinstance(var, Variable) or isinstance(var, str), (
|
|
|
|
|
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
|
|
|
|
|
global_block.append_op(
|
|
|
|
|
type='fetch',
|
|
|
|
|
inputs={'X': [var]},
|
|
|
|
|
outputs={'Out': [fetch_var]},
|
|
|
|
|
attrs={'col': i})
|
|
|
|
|
|
|
|
|
|
# feed var to framework
|
|
|
|
|
for op in program_cache.global_block().ops:
|
|
|
|
|
if op.desc.type() == 'feed':
|
|
|
|
|
feed_target_name = op.desc.output('Out')[0]
|
|
|
|
|
cur_feed = feed[feed_target_name]
|
|
|
|
|
if not isinstance(cur_feed, core.LoDTensor):
|
|
|
|
|
cur_feed = self.aslodtensor(cur_feed)
|
|
|
|
|
idx = op.desc.attr('col')
|
|
|
|
|
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
|
|
|
|
|
else:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
self.executor.run(program_cache.desc, scope, 0, True, True)
|
|
|
|
|
outs = [
|
|
|
|
|
core.get_fetch_variable(scope, fetch_var_name, i)
|
|
|
|
|
for i in xrange(len(fetch_list))
|
|
|
|
|
]
|
|
|
|
|
self.program_caches.pop(cache_key, None)
|
|
|
|
|
program = self._add_feed_fetch_ops(
|
|
|
|
|
program=program,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
feed_var_name=feed_var_name,
|
|
|
|
|
fetch_var_name=fetch_var_name)
|
|
|
|
|
|
|
|
|
|
self._feed_data(program, feed, feed_var_name, scope)
|
|
|
|
|
self.executor.run(program.desc, scope, 0, True, True)
|
|
|
|
|
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
|
|
|
|
|
if return_numpy:
|
|
|
|
|
outs = as_numpy(outs)
|
|
|
|
|
return outs
|
|
|
|
|