|
|
|
@ -208,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True):
|
|
|
|
|
return tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_program_cache_key(feed, fetch_list):
|
|
|
|
|
feed_var_names = list(feed.keys())
|
|
|
|
|
def _to_name_str(var):
|
|
|
|
|
if isinstance(var, Variable):
|
|
|
|
|
return var.desc.name()
|
|
|
|
|
elif isinstance(var, str):
|
|
|
|
|
return var
|
|
|
|
|
elif isinstance(var, six.string_types):
|
|
|
|
|
return str(var)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(str(var) + " should be Variable or str")
|
|
|
|
|
|
|
|
|
|
def to_name_str(var):
|
|
|
|
|
if isinstance(var, Variable):
|
|
|
|
|
return var.desc.name()
|
|
|
|
|
elif isinstance(var, str):
|
|
|
|
|
return var
|
|
|
|
|
elif isinstance(var, six.string_types):
|
|
|
|
|
return str(var)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(str(var) + " should be Variable or str")
|
|
|
|
|
|
|
|
|
|
fetch_var_names = list(map(to_name_str, fetch_list))
|
|
|
|
|
def _get_program_cache_key(feed, fetch_list):
|
|
|
|
|
feed_var_names = list(feed.keys())
|
|
|
|
|
fetch_var_names = list(map(_to_name_str, fetch_list))
|
|
|
|
|
|
|
|
|
|
return str(feed_var_names + fetch_var_names)
|
|
|
|
|
|
|
|
|
@ -397,11 +397,8 @@ class Executor(object):
|
|
|
|
|
self.executor.close()
|
|
|
|
|
self._closed = True
|
|
|
|
|
|
|
|
|
|
def _run_parallel(self,
|
|
|
|
|
scope,
|
|
|
|
|
feed=None,
|
|
|
|
|
fetch_list=None,
|
|
|
|
|
return_numpy=True):
|
|
|
|
|
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name,
|
|
|
|
|
return_numpy):
|
|
|
|
|
if isinstance(feed, dict):
|
|
|
|
|
feed_tensor_dict = dict()
|
|
|
|
|
for feed_name in feed:
|
|
|
|
@ -437,8 +434,8 @@ class Executor(object):
|
|
|
|
|
res.append(res_dict)
|
|
|
|
|
self.executor.feed_tensors_into_local_scopes(res)
|
|
|
|
|
|
|
|
|
|
fetch_var_name = '@FETCHED_VAR_NAME@'
|
|
|
|
|
self.executor.run(fetch_list, fetch_var_name)
|
|
|
|
|
fetch_var_names = list(map(_to_name_str, fetch_list))
|
|
|
|
|
self.executor.run(fetch_var_names, fetch_var_name)
|
|
|
|
|
arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
|
|
|
|
|
|
|
|
|
|
if return_numpy:
|
|
|
|
@ -504,6 +501,8 @@ class Executor(object):
|
|
|
|
|
|
|
|
|
|
if scope is None:
|
|
|
|
|
scope = global_scope()
|
|
|
|
|
if fetch_list is None:
|
|
|
|
|
fetch_list = []
|
|
|
|
|
|
|
|
|
|
compiled = isinstance(program, compiler.CompiledProgram)
|
|
|
|
|
# For backward compatibility, run directly.
|
|
|
|
@ -529,6 +528,7 @@ class Executor(object):
|
|
|
|
|
scope=scope,
|
|
|
|
|
feed=feed,
|
|
|
|
|
fetch_list=fetch_list,
|
|
|
|
|
fetch_var_name=fetch_var_name,
|
|
|
|
|
return_numpy=return_numpy)
|
|
|
|
|
else:
|
|
|
|
|
# TODO(panyx0718): Can compile program to optimize executor
|
|
|
|
@ -552,8 +552,6 @@ class Executor(object):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"feed requires dict as its Parameter. But you passed in %s" %
|
|
|
|
|
(type(feed)))
|
|
|
|
|
if fetch_list is None:
|
|
|
|
|
fetch_list = []
|
|
|
|
|
if program is None:
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
|
|
|
|
|