@@ -123,28 +123,65 @@ class ParallelExecutor(object):
             allow_op_delay)
         self.scope = scope
 
-    def run(self, fetch_list, feed_dict={}):
+    def run(self, fetch_list, feed=None, feed_dict=None):
         """
-        :param fetch_list: A list of variable names that will be fetched.
-        :param feed_dict: A dict mapping for feed variable name to LoDTensor
-          or numpy array.
-        :return: fetched value list.
-        """
-        if not isinstance(feed_dict, dict):
-            raise TypeError("feed_dict should be a dict")
-
-        feed_tensor_dict = {}
-        for i, feed_name in enumerate(feed_dict):
-            feed_tensor = feed_dict[feed_name]
-            if not isinstance(feed_tensor, core.LoDTensor):
-                feed_tensor = core.LoDTensor()
-                feed_tensor.set(feed_dict[feed_name], self._act_places[0])
-            feed_tensor_dict[feed_name] = feed_tensor
+        Args:
+            fetch_list(list): The fetched variable names.
+            feed(list|dict|None): The feed variables. If the feed is a dict, the tensors
+                in that dict will be split across the devices. If the feed is a list,
+                each element of the list will be copied to its corresponding device.
+            feed_dict: Alias for the feed parameter, kept for backward compatibility.
+
+        Returns: fetched result list.
+
+        """
+        if feed is None:
+            feed = feed_dict
+
+        if isinstance(feed, dict):
+            feed_tensor_dict = dict()
+            for feed_name in feed:
+                feed_tensor = feed[feed_name]
+                if not isinstance(feed_tensor, core.LoDTensor):
+                    feed_tensor = core.LoDTensor()
+                    # Always set the tensor on the CPU place first, since it
+                    # needs to be split across devices, and splitting is fast on CPU.
+                    feed_tensor.set(feed[feed_name], core.CPUPlace())
+                feed_tensor_dict[feed_name] = feed_tensor
+
+            self.executor.feed_and_split_tensor_into_local_scopes(
+                feed_tensor_dict)
+        elif isinstance(feed, list) or isinstance(feed, tuple):
+            if len(feed) != len(self._act_places):
+                raise ValueError(
+                    "Feed a list of tensors; the list should be the same size as places"
+                )
+
+            res = list()
+
+            for i, each in enumerate(feed):
+                if not isinstance(each, dict):
+                    raise TypeError(
+                        "Each element of the feed list should be a dict")
+                res_dict = dict()
+                for feed_name in each:
+                    tensor = each[feed_name]
+                    if not isinstance(tensor, core.LoDTensor):
+                        tmp = core.LoDTensor()
+                        tmp.set(tensor, self._act_places[i])
+                        tensor = tmp
+                    res_dict[feed_name] = tensor
+                res.append(res_dict)
+            self.executor.feed_tensors_into_local_scopes(res)
 
         fetch_var_name = '@FETCHED_VAR_NAME@'
-        self.executor.run(fetch_list, fetch_var_name, feed_tensor_dict)
+        self.executor.run(fetch_list, fetch_var_name)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
         return [arr[i] for i in range(len(arr))]
 
     def bcast_params(self):
         self.executor.bcast_params(set(self.persistable_vars))
 
     @property
     def device_count(self):
         return len(self._act_places)
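
A minimal usage sketch of the two feed forms accepted by the new run() signature follows. The program and executor construction are elided; the input name 'image', the fetch target 'loss', and the two-device setup are illustrative assumptions, not part of the patch.

    import numpy

    # `exe` is a ParallelExecutor built elsewhere; assume exe.device_count == 2.

    # Dict form: one global batch. The tensor is set on CPUPlace and split
    # across the devices, so each device receives a (24, 1, 28, 28) slice.
    exe.run(fetch_list=['loss'],
            feed={'image': numpy.random.random(size=(48, 1, 28, 28))})

    # List form: one dict per device. Element i is copied to device i as-is,
    # so per-device batch sizes may differ; the list length must equal
    # exe.device_count.
    exe.run(fetch_list=['loss'],
            feed=[{'image': numpy.random.random(size=(48, 1, 28, 28))},
                  {'image': numpy.random.random(size=(32, 1, 28, 28))}])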