|
|
|
@ -16,6 +16,7 @@ import core
|
|
|
|
|
import multiprocessing
|
|
|
|
|
import framework
|
|
|
|
|
import executor
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
__all__ = ['ParallelExecutor']
|
|
|
|
|
|
|
|
|
@ -125,6 +126,30 @@ class ParallelExecutor(object):
|
|
|
|
|
|
|
|
|
|
def run(self, fetch_list, feed=None, feed_dict=None):
|
|
|
|
|
"""
|
|
|
|
|
Run a parallel executor with fetch_list.
|
|
|
|
|
|
|
|
|
|
The feed parameter can be a dict or a list. If feed is a dict, the
|
|
|
|
|
feed data will be split into multiple devices. If feed is a list, we
|
|
|
|
|
assume the data has been splitted into multiple devices, the each
|
|
|
|
|
element in the list will be copied to each device directly.
|
|
|
|
|
|
|
|
|
|
For example, if the feed is a dict:
|
|
|
|
|
>>> exe = ParallelExecutor()
|
|
|
|
|
>>> # the image will be splitted into devices. If there is two devices
|
|
|
|
|
>>> # each device will process an image with shape (24, 1, 28, 28)
|
|
|
|
|
>>> exe.run(feed={'image': numpy.random.random(size=(48, 1, 28, 28))})
|
|
|
|
|
|
|
|
|
|
For example, if the feed is a list:
|
|
|
|
|
>>> exe = ParallelExecutor()
|
|
|
|
|
>>> # each device will process each element in the list.
|
|
|
|
|
>>> # the 1st device will process an image with shape (48, 1, 28, 28)
|
|
|
|
|
>>> # the 2nd device will process an image with shape (32, 1, 28, 28)
|
|
|
|
|
>>> #
|
|
|
|
|
>>> # you can use exe.device_count to get the device number.
|
|
|
|
|
>>> exe.run(feed=[{"image": numpy.random.random(size=(48, 1, 28, 28))},
|
|
|
|
|
>>> {"image": numpy.random.random(size=(32, 1, 28, 28))},
|
|
|
|
|
>>> ])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fetch_list(list): The fetched variable names
|
|
|
|
@ -133,12 +158,14 @@ class ParallelExecutor(object):
|
|
|
|
|
the feed is a list, each element of the list will be copied
|
|
|
|
|
to each device.
|
|
|
|
|
feed_dict: Alias for feed parameter, for backward compatibility.
|
|
|
|
|
This parameter is deprecated.
|
|
|
|
|
|
|
|
|
|
Returns: fetched result list.
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if feed is None:
|
|
|
|
|
feed = feed_dict
|
|
|
|
|
print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`"
|
|
|
|
|
|
|
|
|
|
if isinstance(feed, dict):
|
|
|
|
|
feed_tensor_dict = dict()
|
|
|
|
|