|
|
|
@ -62,8 +62,8 @@ class DataFeeder(DataProviderConverter):
|
|
|
|
|
self.reader_dict = reader_dict
|
|
|
|
|
for each in data_types:
|
|
|
|
|
self.input_names.append(each[0])
|
|
|
|
|
self.input_types.append(each[1])
|
|
|
|
|
assert isinstance(each[1], data_type.InputType)
|
|
|
|
|
self.input_types.append(each[1])
|
|
|
|
|
DataProviderConverter.__init__(self, self.input_types)
|
|
|
|
|
|
|
|
|
|
def convert(self, dat, argument=None):
|
|
|
|
@ -88,24 +88,16 @@ class DataFeeder(DataProviderConverter):
|
|
|
|
|
:type argument: swig_paddle.Arguments
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if argument is None:
|
|
|
|
|
argument = swig_paddle.Arguments.createArguments(0)
|
|
|
|
|
assert isinstance(argument, swig_paddle.Arguments)
|
|
|
|
|
argument.resize(len(self.input_types))
|
|
|
|
|
|
|
|
|
|
scanners = [
|
|
|
|
|
DataProviderConverter.create_scanner(i, each_type)
|
|
|
|
|
for i, each_type in enumerate(self.input_types)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
for each_sample in dat:
|
|
|
|
|
for name, scanner in zip(self.input_names, scanners):
|
|
|
|
|
scanner.scan(each_sample[self.reader_dict[name]])
|
|
|
|
|
|
|
|
|
|
for scanner in scanners:
|
|
|
|
|
scanner.finish_scan(argument)
|
|
|
|
|
def reorder_data(data):
|
|
|
|
|
retv = []
|
|
|
|
|
for each in data:
|
|
|
|
|
reorder = []
|
|
|
|
|
for name in self.input_names:
|
|
|
|
|
reorder.append(each[self.reader_dict[name]])
|
|
|
|
|
retv.append(reorder)
|
|
|
|
|
return retv
|
|
|
|
|
|
|
|
|
|
return argument
|
|
|
|
|
return DataProviderConverter.convert(self, reorder_data(dat), argument)
|
|
|
|
|
|
|
|
|
|
def __call__(self, dat, argument=None):
|
|
|
|
|
return self.convert(dat, argument)
|
|
|
|
|