|
|
|
@ -15,7 +15,8 @@
|
|
|
|
|
from . import core
|
|
|
|
|
import numpy
|
|
|
|
|
import os
|
|
|
|
|
import six.moves as six
|
|
|
|
|
import six
|
|
|
|
|
from six.moves import zip, range, xrange
|
|
|
|
|
import multiprocessing
|
|
|
|
|
|
|
|
|
|
from .framework import Variable, default_main_program
|
|
|
|
@ -52,7 +53,7 @@ class DataToLoDTensorConverter(object):
|
|
|
|
|
self.data = []
|
|
|
|
|
self.lod = []
|
|
|
|
|
|
|
|
|
|
for i in six.range(lod_level):
|
|
|
|
|
for i in six.moves.range(lod_level):
|
|
|
|
|
self.lod.append([])
|
|
|
|
|
|
|
|
|
|
def feed(self, data):
|
|
|
|
@ -141,7 +142,7 @@ class DataFeeder(object):
|
|
|
|
|
if program is None:
|
|
|
|
|
program = default_main_program()
|
|
|
|
|
for each_var in feed_list:
|
|
|
|
|
if isinstance(each_var, str):
|
|
|
|
|
if isinstance(each_var, six.string_types):
|
|
|
|
|
each_var = program.block(0).var(each_var)
|
|
|
|
|
if not isinstance(each_var, Variable):
|
|
|
|
|
raise TypeError("Feed list should contain a list of variable")
|
|
|
|
@ -173,7 +174,7 @@ class DataFeeder(object):
|
|
|
|
|
dict: the result of conversion.
|
|
|
|
|
"""
|
|
|
|
|
converter = []
|
|
|
|
|
for lod_level, shape, dtype in six.zip(
|
|
|
|
|
for lod_level, shape, dtype in six.moves.zip(
|
|
|
|
|
self.feed_lod_level, self.feed_shapes, self.feed_dtypes):
|
|
|
|
|
converter.append(
|
|
|
|
|
DataToLoDTensorConverter(
|
|
|
|
@ -186,10 +187,12 @@ class DataFeeder(object):
|
|
|
|
|
assert len(each_sample) == len(converter), (
|
|
|
|
|
"The number of fields in data (%s) does not match " +
|
|
|
|
|
"len(feed_list) (%s)") % (len(each_sample), len(converter))
|
|
|
|
|
for each_converter, each_slot in six.zip(converter, each_sample):
|
|
|
|
|
for each_converter, each_slot in six.moves.zip(converter,
|
|
|
|
|
each_sample):
|
|
|
|
|
each_converter.feed(each_slot)
|
|
|
|
|
ret_dict = {}
|
|
|
|
|
for each_name, each_converter in six.zip(self.feed_names, converter):
|
|
|
|
|
for each_name, each_converter in six.moves.zip(self.feed_names,
|
|
|
|
|
converter):
|
|
|
|
|
ret_dict[each_name] = each_converter.done()
|
|
|
|
|
return ret_dict
|
|
|
|
|
|
|
|
|
@ -211,12 +214,14 @@ class DataFeeder(object):
|
|
|
|
|
if isinstance(self.place, core.CUDAPlace):
|
|
|
|
|
places = [
|
|
|
|
|
core.CUDAPlace(i)
|
|
|
|
|
for i in six.xrange(self._get_number_of_places_(num_places))
|
|
|
|
|
for i in six.moves.xrange(
|
|
|
|
|
self._get_number_of_places_(num_places))
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
places = [
|
|
|
|
|
core.CPUPlace()
|
|
|
|
|
for _ in six.xrange(self._get_number_of_places_(num_places))
|
|
|
|
|
for _ in six.moves.xrange(
|
|
|
|
|
self._get_number_of_places_(num_places))
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if len(iterable) != len(places):
|
|
|
|
@ -226,7 +231,7 @@ class DataFeeder(object):
|
|
|
|
|
"must be same.")
|
|
|
|
|
|
|
|
|
|
place = self.place
|
|
|
|
|
for p, batch in six.zip(places, iterable):
|
|
|
|
|
for p, batch in six.moves.zip(places, iterable):
|
|
|
|
|
self.place = p
|
|
|
|
|
yield self.feed(batch)
|
|
|
|
|
self.place = place
|
|
|
|
|