|
|
|
@ -49,16 +49,6 @@ Status WeightedRandomSampler::InitSampler() {
|
|
|
|
|
"Invalid parameter, samples_per_buffer must be greater than 0, but got " +
|
|
|
|
|
std::to_string(samples_per_buffer_) + ".\n");
|
|
|
|
|
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(weights_.size() != 0, "Invalid parameter, weights size must not be 0.\n");
|
|
|
|
|
int32_t zero_elem = 0;
|
|
|
|
|
for (auto &elem : weights_) {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(elem >= 0.0, "Invalid parameter, weights must not contain negative number, but got " +
|
|
|
|
|
std::to_string(elem) + ".\n");
|
|
|
|
|
if (elem == 0.0) zero_elem++;
|
|
|
|
|
}
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(zero_elem != weights_.size(),
|
|
|
|
|
"Invalid parameter, elements of weights must not be all zero.\n");
|
|
|
|
|
|
|
|
|
|
if (weights_.size() > static_cast<size_t>(num_rows_)) {
|
|
|
|
|
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
|
|
|
|
"Invalid parameter, size of sample weights must be less than or equal to num of data, "
|
|
|
|
|