Fix source len is not divisible by batch_size in user defined sampler

pull/13130/head
luoyang 4 years ago
parent 9134e7b2d4
commit 25ff4c9312

@ -237,6 +237,7 @@ class Sampler(BuiltinSampler):
# Indices fetcher
# Do not override this method!
# pylint: disable=missing-docstring
def _get_indices(self):
sampler_iter = iter(self)
ret = []
@ -246,7 +247,10 @@ class Sampler(BuiltinSampler):
ret.append(idx)
except StopIteration:
break
return np.array(ret)
indices = np.array(ret)
if indices.dtype == object:
raise RuntimeError("Fetched indices can not be converted to a valid ndarray.")
return indices
# Instance fetcher
# Do not override this method!

Loading…
Cancel
Save