From 25ff4c9312368883709ba5eeacff651323c5a9ca Mon Sep 17 00:00:00 2001 From: luoyang Date: Wed, 10 Mar 2021 21:25:32 +0800 Subject: [PATCH] Fix source len is not divisible by batch_size in user defined sampler --- mindspore/dataset/engine/samplers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 748e20dda9..1774d26e5e 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -237,6 +237,7 @@ class Sampler(BuiltinSampler): # Indices fetcher # Do not override this method! + # pylint: disable=missing-docstring def _get_indices(self): sampler_iter = iter(self) ret = [] @@ -246,7 +247,10 @@ class Sampler(BuiltinSampler): ret.append(idx) except StopIteration: break - return np.array(ret) + indices = np.array(ret) + if indices.dtype == object: + raise RuntimeError("Fetched indices can not be converted to a valid ndarray.") + return indices # Instance fetcher # Do not override this method!