diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index cee38dfa35..56ef705f60 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -506,6 +506,9 @@ class SubsetRandomSampler(BuiltinSampler): def get_num_samples(self): num_samples = super().get_num_samples() + if num_samples is None: + return len(self.indices) + return min(len(self.indices), num_samples)