|
|
|
@ -87,8 +87,7 @@ std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index = 0,
|
|
|
|
|
/// \param[in] indices - A vector sequence of indices.
|
|
|
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
|
|
|
|
/// \return Shared pointer to the current Sampler.
|
|
|
|
|
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices,
|
|
|
|
|
int64_t num_samples = 0);
|
|
|
|
|
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(std::vector<int64_t> indices, int64_t num_samples = 0);
|
|
|
|
|
|
|
|
|
|
/// Function to create a Weighted Random Sampler.
|
|
|
|
|
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
|
|
|
|
@ -97,8 +96,8 @@ std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<in
|
|
|
|
|
/// \param[in] num_samples - The number of samples to draw (default to all elements).
|
|
|
|
|
/// \param[in] replacement - If True, put the sample ID back for the next draw.
|
|
|
|
|
/// \return Shared pointer to the current Sampler.
|
|
|
|
|
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights,
|
|
|
|
|
int64_t num_samples = 0, bool replacement = true);
|
|
|
|
|
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(std::vector<double> weights, int64_t num_samples = 0,
|
|
|
|
|
bool replacement = true);
|
|
|
|
|
|
|
|
|
|
/* ####################################### Derived Sampler classes ################################# */
|
|
|
|
|
class DistributedSamplerObj : public SamplerObj {
|
|
|
|
@ -169,7 +168,7 @@ class SequentialSamplerObj : public SamplerObj {
|
|
|
|
|
|
|
|
|
|
class SubsetRandomSamplerObj : public SamplerObj {
|
|
|
|
|
public:
|
|
|
|
|
SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples);
|
|
|
|
|
SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples);
|
|
|
|
|
|
|
|
|
|
~SubsetRandomSamplerObj() = default;
|
|
|
|
|
|
|
|
|
@ -178,14 +177,13 @@ class SubsetRandomSamplerObj : public SamplerObj {
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const std::vector<int64_t> &indices_;
|
|
|
|
|
const std::vector<int64_t> indices_;
|
|
|
|
|
int64_t num_samples_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class WeightedRandomSamplerObj : public SamplerObj {
|
|
|
|
|
public:
|
|
|
|
|
explicit WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples = 0,
|
|
|
|
|
bool replacement = true);
|
|
|
|
|
explicit WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples = 0, bool replacement = true);
|
|
|
|
|
|
|
|
|
|
~WeightedRandomSamplerObj() = default;
|
|
|
|
|
|
|
|
|
@ -194,7 +192,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
|
|
|
|
|
bool ValidateParams() override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const std::vector<double> &weights_;
|
|
|
|
|
const std::vector<double> weights_;
|
|
|
|
|
int64_t num_samples_;
|
|
|
|
|
bool replacement_;
|
|
|
|
|
};
|
|
|
|
|