|
|
|
@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
|
|
|
|
|
num_val (int): Number of elements to sample for each class.
|
|
|
|
|
num_class (int, optional): Number of classes to sample (default=None, all classes).
|
|
|
|
|
shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
|
|
|
|
|
class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> import mindspore.dataset as ds
|
|
|
|
@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
|
|
|
|
|
ValueError: If shuffle is not boolean.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, num_val, num_class=None, shuffle=False):
|
|
|
|
|
def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
|
|
|
|
|
if num_val <= 0:
|
|
|
|
|
raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
|
|
|
|
|
|
|
|
|
@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
|
|
|
|
|
|
|
|
|
|
self.num_val = num_val
|
|
|
|
|
self.shuffle = shuffle
|
|
|
|
|
self.class_column = class_column # work for minddataset
|
|
|
|
|
|
|
|
|
|
def create(self):
|
|
|
|
|
return cde.PKSampler(self.num_val, self.shuffle)
|
|
|
|
|
|
|
|
|
|
def _create_for_minddataset(self):
|
|
|
|
|
return cde.MindrecordPkSampler(self.num_val, self.shuffle)
|
|
|
|
|
if not self.class_column or not isinstance(self.class_column, str):
|
|
|
|
|
raise ValueError("class_column should be a not empty string value, \
|
|
|
|
|
but got class_column={}".format(class_column))
|
|
|
|
|
return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
|
|
|
|
|
|
|
|
|
|
class RandomSampler(BuiltinSampler):
|
|
|
|
|
"""
|
|
|
|
|