|
|
|
@ -17,6 +17,7 @@ import six
|
|
|
|
|
import sys
|
|
|
|
|
import time
|
|
|
|
|
import signal
|
|
|
|
|
import numbers
|
|
|
|
|
import logging
|
|
|
|
|
import itertools
|
|
|
|
|
import threading
|
|
|
|
@ -81,12 +82,17 @@ def default_collate_fn(batch):
|
|
|
|
|
else:
|
|
|
|
|
slots[i].append(item)
|
|
|
|
|
|
|
|
|
|
if isinstance(slots[0][0], np.ndarray):
|
|
|
|
|
return [np.stack(slot, axis=0) for slot in slots]
|
|
|
|
|
elif isinstance(slots[0][0], paddle.Tensor):
|
|
|
|
|
return [layers.stack(slot, axis=0) for slot in slots]
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Unknown data type {}".format(type(slots[0][0])))
|
|
|
|
|
outputs = []
|
|
|
|
|
for slot in slots:
|
|
|
|
|
if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
|
|
|
|
|
tmp = np.stack(slot, axis=0)
|
|
|
|
|
outputs.append(tmp)
|
|
|
|
|
elif isinstance(slot[0], paddle.Tensor):
|
|
|
|
|
tmp = layers.stack(slot, axis=0)
|
|
|
|
|
outputs.append(tmp)
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Unknown data type {}".format(type(slot[0])))
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetKind(object):
|
|
|
|
|