|
|
|
@ -50,11 +50,15 @@ Status WeightedRandomSampler::InitSampler() {
|
|
|
|
|
std::to_string(samples_per_buffer_) + ".\n");
|
|
|
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) {
|
|
|
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
|
|
|
|
"Invalid parameter, number of samples weights is more than num of rows. "
|
|
|
|
|
"Might generate id out of bound OR other errors");
|
|
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, "
|
|
|
|
|
"otherwise might cause generated id out of bound or other errors, but got weight size: " +
|
|
|
|
|
std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
|
|
|
|
|
}
|
|
|
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("Invalid parameter, without replacement, weights size must be greater than num_samples.");
|
|
|
|
|
RETURN_STATUS_UNEXPECTED(
|
|
|
|
|
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, "
|
|
|
|
|
"but got weight size: " +
|
|
|
|
|
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Initialize random generator with seed from config manager
|
|
|
|
@ -110,11 +114,16 @@ Status WeightedRandomSampler::ResetSampler() {
|
|
|
|
|
Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
|
|
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) {
|
|
|
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
|
|
|
|
"number of samples weights is more than num of rows. Might generate id out of bound OR other errors");
|
|
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, "
|
|
|
|
|
"otherwise might cause generated id out of bound or other errors, but got weight size: " +
|
|
|
|
|
std::to_string(weights_.size()) + ", num of data: " + std::to_string(num_rows_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples");
|
|
|
|
|
RETURN_STATUS_UNEXPECTED(
|
|
|
|
|
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, "
|
|
|
|
|
"but got weight size: " +
|
|
|
|
|
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (sample_id_ == num_samples_) {
|
|
|
|
@ -150,7 +159,8 @@ Status WeightedRandomSampler::GetNextSample(std::unique_ptr<DataBuffer> *out_buf
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (genId >= num_rows_) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("generated id is bigger than numRows (out of bound).");
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("Generated indice is out of bound, expect range [0, num_data-1], got indice: " +
|
|
|
|
|
std::to_string(genId) + ", num_data: " + std::to_string(num_rows_ - 1));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (HasChildSampler()) {
|
|
|
|
|