diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index ce732d28a7..972f0af191 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -19,8 +19,8 @@ SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. User can also define custom sampler by extending from Sampler class. """ -import mindspore._c_dataengine as cde import numpy as np +import mindspore._c_dataengine as cde class Sampler: @@ -137,6 +137,7 @@ class DistributedSampler(BuiltinSampler): self.shard_id = shard_id self.shuffle = shuffle self.seed = 0 + super().__init__() def create(self): # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle @@ -182,6 +183,7 @@ class PKSampler(BuiltinSampler): self.num_val = num_val self.shuffle = shuffle self.class_column = class_column # work for minddataset + super().__init__() def create(self): return cde.PKSampler(self.num_val, self.shuffle) @@ -192,6 +194,7 @@ class PKSampler(BuiltinSampler): but got class_column={}".format(class_column)) return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) + class RandomSampler(BuiltinSampler): """ Samples the elements randomly. @@ -225,6 +228,7 @@ class RandomSampler(BuiltinSampler): self.replacement = replacement self.num_samples = num_samples + super().__init__() def create(self): # If num_samples is not specified, then call constructor #2 @@ -275,6 +279,7 @@ class SubsetRandomSampler(BuiltinSampler): indices = [indices] self.indices = indices + super().__init__() def create(self): return cde.SubsetRandomSampler(self.indices) @@ -322,6 +327,7 @@ class WeightedRandomSampler(BuiltinSampler): self.weights = weights self.num_samples = num_samples self.replacement = replacement + super().__init__() def create(self): return cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement)