Speedup reader check_input_array when item is array (#25395)

fix_copy_if_different
WangXi 5 years ago committed by GitHub
parent 52be62c5ae
commit 1c7215ace4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -97,6 +97,18 @@ class DataLoaderBase(object):
def __next__(self):
raise NotImplementedError()
@classmethod
def _check_input_array(cls, item):
arr = np.asarray(item)
if arr.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
return arr
class DataLoader(object):
"""
@ -807,17 +819,6 @@ class DygraphGeneratorLoader(DataLoaderBase):
self._reset()
six.reraise(*sys.exc_info())
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError(
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor.")
def _exit_thread_expectedly(self):
self._thread_done_event.set()
self._blocking_queue.close()
@ -894,7 +895,7 @@ class DygraphGeneratorLoader(DataLoaderBase):
array = core.LoDTensorArray()
for item in sample:
if not isinstance(item, core.LoDTensor):
self._check_input_array(item)
item = self._check_input_array(item)
tmp = core.LoDTensor()
tmp.set(item, core.CPUPlace())
item = tmp
@ -1115,19 +1116,6 @@ class GeneratorLoader(DataLoaderBase):
assert not self._iterable, "reset() cannot be called when DataLoader is iterable"
self._reset()
@classmethod
def _check_input_array(cls, item):
arr = np.array(item)
if arr.dtype == np.object:
raise TypeError((
"\n\tFaild to convert input data to a regular ndarray :\n\t* Usually "
"this means the input data contains nested lists with different lengths. "
"\n\t* Check the reader function passed to 'decorate_batch_generator'"
" to locate the data causes this issue.\n\t* Please consider using "
"'fluid.create_lod_tensor' to convert it to a LoD-Tensor."))
return arr
def _start(self):
def __thread_main__():
try:

Loading…
Cancel
Save