|
|
|
@ -21,7 +21,7 @@ import threading
|
|
|
|
|
import paddle
|
|
|
|
|
from .framework import Program, Variable, program_guard, default_main_program, default_startup_program, in_dygraph_mode, cpu_places
|
|
|
|
|
from .executor import global_scope
|
|
|
|
|
from .data_feeder import DataFeeder, BatchedTensorProvider, ListTensorProvider
|
|
|
|
|
from .data_feeder import DataFeeder, BatchedTensorProvider, DygraphListTensorProvider
|
|
|
|
|
from .layers.io import monkey_patch_reader_methods, _copy_reader_var_, double_buffer
|
|
|
|
|
from .unique_name import UniqueNameGenerator
|
|
|
|
|
import logging
|
|
|
|
@ -442,14 +442,13 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
|
|
|
|
|
def __next__(self):
|
|
|
|
|
try:
|
|
|
|
|
if not in_dygraph_mode():
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
return self._reader.read_next_var_list()
|
|
|
|
|
else:
|
|
|
|
|
if self._return_list:
|
|
|
|
|
return self._reader.read_next_list()
|
|
|
|
|
else:
|
|
|
|
|
return self._reader.read_next()
|
|
|
|
|
else:
|
|
|
|
|
ret = self._reader.read_next_list()[0]
|
|
|
|
|
return [dygraph.base.to_variable(np.array(v)) for v in ret]
|
|
|
|
|
except StopIteration:
|
|
|
|
|
self._queue.close()
|
|
|
|
|
self._reset()
|
|
|
|
@ -517,7 +516,12 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
drop_last=True,
|
|
|
|
|
places=None):
|
|
|
|
|
assert batch_size > 0, "batch_size must be larger than 0"
|
|
|
|
|
if not in_dygraph_mode():
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
self.set_sample_list_generator(
|
|
|
|
|
paddle.batch(
|
|
|
|
|
reader, batch_size=batch_size, drop_last=drop_last),
|
|
|
|
|
places=places)
|
|
|
|
|
else:
|
|
|
|
|
has_lod = False
|
|
|
|
|
for f in self._feed_list:
|
|
|
|
|
if f.lod_level != 0:
|
|
|
|
@ -537,15 +541,16 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
generator=reader,
|
|
|
|
|
drop_last=drop_last)
|
|
|
|
|
self.set_batch_generator(reader, places=places)
|
|
|
|
|
else:
|
|
|
|
|
self.set_sample_list_generator(
|
|
|
|
|
paddle.batch(
|
|
|
|
|
reader, batch_size=batch_size, drop_last=drop_last),
|
|
|
|
|
places=places)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def set_sample_list_generator(self, reader, places=None):
|
|
|
|
|
if not in_dygraph_mode():
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
provider = DygraphListTensorProvider(reader, places)
|
|
|
|
|
|
|
|
|
|
def __tensor_reader_impl__():
|
|
|
|
|
for slots in provider():
|
|
|
|
|
yield slots[0]
|
|
|
|
|
else:
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
feeder = DataFeeder(
|
|
|
|
|
feed_list=self._feed_list, place=core.CPUPlace())
|
|
|
|
@ -555,12 +560,6 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
def __tensor_reader_impl__():
|
|
|
|
|
for slots in paddle_reader():
|
|
|
|
|
yield [slots[var.name] for var in self._feed_list]
|
|
|
|
|
else:
|
|
|
|
|
provider = ListTensorProvider(reader, places)
|
|
|
|
|
|
|
|
|
|
def __tensor_reader_impl__():
|
|
|
|
|
for slots in provider():
|
|
|
|
|
yield slots[0]
|
|
|
|
|
|
|
|
|
|
self.set_batch_generator(__tensor_reader_impl__, places)
|
|
|
|
|
return self
|
|
|
|
@ -571,8 +570,8 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
assert places is not None, "Places cannot be None when DataLoader is iterable"
|
|
|
|
|
self._places = _convert_places(places)
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
assert len(self._places
|
|
|
|
|
) == 1, "Number of places must be 1 in dygraph mode"
|
|
|
|
|
assert len(self._places) == 1, \
|
|
|
|
|
"Number of places must be 1 in dygraph mode"
|
|
|
|
|
else:
|
|
|
|
|
if places is not None:
|
|
|
|
|
logging.info(
|
|
|
|
|