@@ -16,6 +16,7 @@ from __future__ import print_function
 import core
 import numpy
 import six.moves as six
+import multiprocessing
 
 from framework import Variable, default_main_program
 
@@ -116,3 +117,60 @@ class DataFeeder(object):
         for each_name, each_converter in six.zip(self.feed_names, converter):
             ret_dict[each_name] = each_converter.done()
         return ret_dict
+
+    def feed_parallel(self, iterable, num_places=None):
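+        """Feed one mini-batch per device.
+
+        ``iterable`` must hold exactly one mini-batch for every place
+        reported by ``_get_number_of_places_``; one feed dict per device
+        is yielded lazily by temporarily rebinding ``self.place``.
+        """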
+        if isinstance(self.place, core.CUDAPlace):
+            places = [
+                core.CUDAPlace(i)
+                for i in six.xrange(self._get_number_of_places_(num_places))
+            ]
+        else:
+            places = [
+                core.CPUPlace()
+                for _ in six.xrange(self._get_number_of_places_(num_places))
+            ]
+
+        if len(iterable) != len(places):
+            raise ValueError("feed_parallel takes multiple mini-batches. "
+                             "Each mini-batch will be fed to one device, so "
+                             "the number of devices and the number of "
+                             "mini-batches must be the same.")
+
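+        # Rebind self.place to each device in turn so that self.feed() builds
+        # the tensors on that device, then restore the original place.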
+        place = self.place
+        for p, batch in six.zip(places, iterable):
+            self.place = p
+            yield self.feed(batch)
+        self.place = place
+
+    def _get_number_of_places_(self, num_places):
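+        """Return num_places if given, otherwise the number of CUDA devices
+        or CPU cores, depending on self.place."""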
+        if num_places is not None:
+            return int(num_places)
+        elif isinstance(self.place, core.CUDAPlace):
+            return core.get_cuda_device_count()
+        else:
+            return multiprocessing.cpu_count()
+
+    def decorate_reader(self,
+                        reader,
+                        multi_devices,
+                        num_places=None,
+                        drop_last=True):
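+        """Wrap ``reader`` so that every item it yields has already been
+        converted by ``feed``.
+
+        With ``multi_devices=True`` each item is a list containing one feed
+        dict per device; otherwise it is a single feed dict. Trailing
+        mini-batches that cannot fill every device are dropped; only
+        ``drop_last=True`` is supported for them.
+        """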
+        def __reader_creator__():
+            if not multi_devices:
+                for item in reader():
+                    yield self.feed(item)
+            else:
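+                # Multi-device path: collect one mini-batch per device, then
+                # hand the whole group to feed_parallel.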
+                num = self._get_number_of_places_(num_places)
+                item = []
+                for batch in reader():
+                    item.append(batch)
+                    if len(item) == num:
+                        yield list(self.feed_parallel(item, num))
+                        item = []
+                if not drop_last and len(item) != 0:
+                    raise ValueError(
+                        "drop_last=False is not implemented: the trailing "
+                        "mini-batches that cannot fill every device can "
+                        "currently only be dropped.")
+
+        return __reader_creator__
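For context, a minimal usage sketch of the decorated reader. The variable
names img, label and train_reader are illustrative placeholders, and the
DataFeeder constructor arguments (feed_list, place) are assumed from the rest
of the file rather than from this patch:

    place = core.CUDAPlace(0)  # or core.CPUPlace() on a CPU-only build
    feeder = DataFeeder(feed_list=[img, label], place=place)
    reader = feeder.decorate_reader(train_reader, multi_devices=True)
    for feed_dicts in reader():
        # feed_dicts is a list with one feed dict per device for this step
        ...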