Fix dataloader when stack input data with different type (#27950)

* fix dataloader
swt-req
LielinJiang 4 years ago committed by GitHub
parent 0b733e4fd0
commit 8327accc58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -17,6 +17,7 @@ import six
import sys
import time
import signal
import numbers
import logging
import itertools
import threading
@ -81,12 +82,17 @@ def default_collate_fn(batch):
else:
slots[i].append(item)
if isinstance(slots[0][0], np.ndarray):
return [np.stack(slot, axis=0) for slot in slots]
elif isinstance(slots[0][0], paddle.Tensor):
return [layers.stack(slot, axis=0) for slot in slots]
else:
raise RuntimeError("Unknown data type {}".format(type(slots[0][0])))
outputs = []
for slot in slots:
if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
tmp = np.stack(slot, axis=0)
outputs.append(tmp)
elif isinstance(slot[0], paddle.Tensor):
tmp = layers.stack(slot, axis=0)
outputs.append(tmp)
else:
raise RuntimeError("Unknown data type {}".format(type(slot[0])))
return outputs
class _DatasetKind(object):

Loading…
Cancel
Save