|
|
|
@ -23,6 +23,7 @@ import numbers
|
|
|
|
|
import numpy as np
|
|
|
|
|
import mindspore._c_dataengine as cde
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
from ..core import validator_helpers as validator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
|
|
|
|
@ -349,8 +350,12 @@ class DistributedSampler(BuiltinSampler):
|
|
|
|
|
if not isinstance(shuffle, bool):
|
|
|
|
|
raise TypeError("shuffle must be a boolean value but was: {}.".format(shuffle))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
if not isinstance(offset, int):
|
|
|
|
|
raise TypeError("offset must be integer but was: {}.".format(offset))
|
|
|
|
@ -441,8 +446,12 @@ class PKSampler(BuiltinSampler):
|
|
|
|
|
if not isinstance(class_column, str):
|
|
|
|
|
raise TypeError("class_column must be a str value but was: {}.".format(class_column))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
self.num_val = num_val
|
|
|
|
|
self.shuffle = shuffle
|
|
|
|
@ -505,8 +514,12 @@ class RandomSampler(BuiltinSampler):
|
|
|
|
|
if not isinstance(replacement, bool):
|
|
|
|
|
raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
self.deterministic = False
|
|
|
|
|
self.replacement = replacement
|
|
|
|
@ -564,8 +577,12 @@ class SequentialSampler(BuiltinSampler):
|
|
|
|
|
if start_index is not None and not isinstance(start_index, int):
|
|
|
|
|
raise TypeError("start_index must be integer but was: {}.".format(start_index))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
self.start_index = start_index
|
|
|
|
|
super().__init__(num_samples)
|
|
|
|
@ -631,8 +648,12 @@ class SubsetSampler(BuiltinSampler):
|
|
|
|
|
raise TypeError("type of indices element must be number, "
|
|
|
|
|
"but got w[{}]: {}, type: {}.".format(i, item, type(item)))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
self.indices = indices
|
|
|
|
|
super().__init__(num_samples)
|
|
|
|
@ -744,8 +765,12 @@ class WeightedRandomSampler(BuiltinSampler):
|
|
|
|
|
raise TypeError("type of weights element must be number, "
|
|
|
|
|
"but got w[{}]: {}, type: {}.".format(ind, w, type(w)))
|
|
|
|
|
|
|
|
|
|
if num_samples is not None and not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples is not None:
|
|
|
|
|
if not isinstance(num_samples, int):
|
|
|
|
|
raise TypeError("num_samples must be integer but was: {}.".format(num_samples))
|
|
|
|
|
if num_samples < 0 or num_samples > validator.INT64_MAX:
|
|
|
|
|
raise ValueError("num_samples exceeds the boundary between {} and {}(INT64_MAX)!"
|
|
|
|
|
.format(0, validator.INT64_MAX))
|
|
|
|
|
|
|
|
|
|
if not isinstance(replacement, bool):
|
|
|
|
|
raise TypeError("replacement must be a boolean value but was: {}.".format(replacement))
|
|
|
|
|