|
|
|
@ -97,6 +97,18 @@ class DataLoaderBase(object):
|
|
|
|
|
def __next__(self):
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _check_input_array(cls, item):
|
|
|
|
|
arr = np.asarray(item)
|
|
|
|
|
if arr.dtype == np.object:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
|
|
|
|
|
"this means the input data contains nested lists with different lengths. "
|
|
|
|
|
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
|
|
|
|
|
" to locate the data causes this issue.\n\t* Please consider using "
|
|
|
|
|
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
|
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataLoader(object):
|
|
|
|
|
"""
|
|
|
|
@ -807,17 +819,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
|
|
|
|
|
self._reset()
|
|
|
|
|
six.reraise(*sys.exc_info())
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _check_input_array(cls, item):
|
|
|
|
|
arr = np.array(item)
|
|
|
|
|
if arr.dtype == np.object:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
|
|
|
|
|
"this means the input data contains nested lists with different lengths. "
|
|
|
|
|
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
|
|
|
|
|
" to locate the data causes this issue.\n\t* Please consider using "
|
|
|
|
|
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
|
|
|
|
|
|
|
|
|
|
def _exit_thread_expectedly(self):
|
|
|
|
|
self._thread_done_event.set()
|
|
|
|
|
self._blocking_queue.close()
|
|
|
|
@ -894,7 +895,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
|
|
|
|
|
array = core.LoDTensorArray()
|
|
|
|
|
for item in sample:
|
|
|
|
|
if not isinstance(item, core.LoDTensor):
|
|
|
|
|
self._check_input_array(item)
|
|
|
|
|
item = self._check_input_array(item)
|
|
|
|
|
tmp = core.LoDTensor()
|
|
|
|
|
tmp.set(item, core.CPUPlace())
|
|
|
|
|
item = tmp
|
|
|
|
@ -1115,19 +1116,6 @@ class GeneratorLoader(DataLoaderBase):
|
|
|
|
|
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
|
|
|
|
|
self._reset()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def _check_input_array(cls, item):
|
|
|
|
|
arr = np.array(item)
|
|
|
|
|
if arr.dtype == np.object:
|
|
|
|
|
raise TypeError((
|
|
|
|
|
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
|
|
|
|
|
"this means the input data contains nested lists with different lengths. "
|
|
|
|
|
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
|
|
|
|
|
" to locate the data causes this issue.\n\t* Please consider using "
|
|
|
|
|
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
|
|
|
|
|
|
|
|
|
|
return arr
|
|
|
|
|
|
|
|
|
|
def _start(self):
|
|
|
|
|
def __thread_main__():
|
|
|
|
|
try:
|
|
|
|
|