test=develop
revert-15207-remove_op_handle_lock_and_fix_var
Xin Pan 6 years ago
parent 5f0a0286e0
commit c4b09a713f

@ -77,6 +77,7 @@ class PreparedOp {
framework::OperatorWithKernel::OpKernelFunc func;
platform::DeviceContext* dev_ctx;
};
class OpBase;
class VarBase {

@ -208,20 +208,20 @@ def _fetch_var(name, scope=None, return_numpy=True):
return tensor
def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
def _to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
def to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
else:
raise TypeError(str(var) + " should be Variable or str")
fetch_var_names = list(map(to_name_str, fetch_list))
def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
fetch_var_names = list(map(_to_name_str, fetch_list))
return str(feed_var_names + fetch_var_names)
@ -397,11 +397,8 @@ class Executor(object):
self.executor.close()
self._closed = True
def _run_parallel(self,
scope,
feed=None,
fetch_list=None,
return_numpy=True):
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name,
return_numpy):
if isinstance(feed, dict):
feed_tensor_dict = dict()
for feed_name in feed:
@ -437,8 +434,8 @@ class Executor(object):
res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@'
self.executor.run(fetch_list, fetch_var_name)
fetch_var_names = list(map(_to_name_str, fetch_list))
self.executor.run(fetch_var_names, fetch_var_name)
arr = scope.find_var(fetch_var_name).get_lod_tensor_array()
if return_numpy:
@ -504,6 +501,8 @@ class Executor(object):
if scope is None:
scope = global_scope()
if fetch_list is None:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
@ -529,6 +528,7 @@ class Executor(object):
scope=scope,
feed=feed,
fetch_list=fetch_list,
fetch_var_name=fetch_var_name,
return_numpy=return_numpy)
else:
# TODO(panyx0718): Can compile program to optimize executor
@ -552,8 +552,6 @@ class Executor(object):
raise TypeError(
"feed requires dict as its Parameter. But you passed in %s" %
(type(feed)))
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()

@ -279,7 +279,7 @@ class ParallelExecutor(object):
res.append(res_dict)
self.executor.feed_tensors_into_local_scopes(res)
fetch_var_name = '@FETCHED_VAR_NAME@'
fetch_var_name = 'fetch'
self.executor.run(fetch_list, fetch_var_name)
arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()

Loading…
Cancel
Save