Fix dataloader when stack input data with different type (#27950)

* fix dataloader
swt-req
LielinJiang 5 years ago committed by GitHub
parent 0b733e4fd0
commit 8327accc58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,6 +17,7 @@ import six
import sys import sys
import time import time
import signal import signal
import numbers
import logging import logging
import itertools import itertools
import threading import threading
@ -81,12 +82,17 @@ def default_collate_fn(batch):
else: else:
slots[i].append(item) slots[i].append(item)
if isinstance(slots[0][0], np.ndarray): outputs = []
return [np.stack(slot, axis=0) for slot in slots] for slot in slots:
elif isinstance(slots[0][0], paddle.Tensor): if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
return [layers.stack(slot, axis=0) for slot in slots] tmp = np.stack(slot, axis=0)
outputs.append(tmp)
elif isinstance(slot[0], paddle.Tensor):
tmp = layers.stack(slot, axis=0)
outputs.append(tmp)
else: else:
raise RuntimeError("Unknown data type {}".format(type(slots[0][0]))) raise RuntimeError("Unknown data type {}".format(type(slot[0])))
return outputs
class _DatasetKind(object): class _DatasetKind(object):

Loading…
Cancel
Save