|
|
|
@ -21,6 +21,7 @@
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include <limits>
|
|
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh"
|
|
|
|
@ -55,15 +56,15 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
set_input_.insert(item);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int counter = Sampling();
|
|
|
|
|
float prob = Probability();
|
|
|
|
|
int64_t counter = Sampling();
|
|
|
|
|
S prob = Probability();
|
|
|
|
|
size_t sampled_candidates_size = num_sampled_ * sizeof(T);
|
|
|
|
|
S value = ApproximateExpectedCount(prob, num_sampled_, counter);
|
|
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
|
|
|
|
cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size,
|
|
|
|
|
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
|
|
|
|
"cudaMemcpyAsync sampled_candidates failed");
|
|
|
|
|
CalUniformCandidateSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count,
|
|
|
|
|
CalUniformCandidateSampler(static_cast<int64_t>(input_size_), num_sampled_, value, true_expected_count,
|
|
|
|
|
sampled_expected_count, reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -81,11 +82,11 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// getting attrs
|
|
|
|
|
num_true_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_true"));
|
|
|
|
|
num_sampled_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_sampled"));
|
|
|
|
|
num_true_ = GetAttr<int64_t>(kernel_node, "num_true");
|
|
|
|
|
num_sampled_ = GetAttr<int64_t>(kernel_node, "num_sampled");
|
|
|
|
|
unique_ = GetAttr<bool>(kernel_node, "unique");
|
|
|
|
|
range_max_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "range_max"));
|
|
|
|
|
int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
|
|
|
|
|
range_max_ = GetAttr<int64_t>(kernel_node, "range_max");
|
|
|
|
|
int64_t seed = GetAttr<int64_t>(kernel_node, "seed");
|
|
|
|
|
remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits");
|
|
|
|
|
if (seed == 0) seed = time(NULL);
|
|
|
|
|
generator_.seed(seed);
|
|
|
|
@ -95,7 +96,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
input_size_ = input_shape[0] * input_shape[1];
|
|
|
|
|
if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) {
|
|
|
|
|
if (num_sampled_ * num_true_ + static_cast<int64_t>(input_size_) > range_max_ * num_true_) {
|
|
|
|
|
remove_accidental_hits_ = false;
|
|
|
|
|
}
|
|
|
|
|
InitSizeLists();
|
|
|
|
@ -110,13 +111,18 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
output_size_list_.push_back(num_sampled_ * sizeof(S));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int Sampling() {
|
|
|
|
|
int counter = 0;
|
|
|
|
|
int tmp;
|
|
|
|
|
int picked;
|
|
|
|
|
std::set<int> set_container;
|
|
|
|
|
int64_t Sampling() {
|
|
|
|
|
int64_t counter = 0;
|
|
|
|
|
T tmp;
|
|
|
|
|
int64_t picked;
|
|
|
|
|
std::set<T> set_container;
|
|
|
|
|
// pick between [0, range_max_-1]
|
|
|
|
|
std::uniform_int_distribution<int> distribution(0, range_max_ - 1);
|
|
|
|
|
T range;
|
|
|
|
|
if (range_max_ > static_cast<int64_t>(std::numeric_limits<T>::max())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
|
|
|
|
|
}
|
|
|
|
|
range = static_cast<T>(range_max_);
|
|
|
|
|
std::uniform_int_distribution<T> distribution(0, range - 1);
|
|
|
|
|
sampled_candidates_.clear();
|
|
|
|
|
if (unique_) {
|
|
|
|
|
picked = 0;
|
|
|
|
@ -131,7 +137,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
for (int i = 0; i < num_sampled_; i++) {
|
|
|
|
|
for (int64_t i = 0; i < num_sampled_; i++) {
|
|
|
|
|
sampled_candidates_.push_back(distribution(generator_));
|
|
|
|
|
}
|
|
|
|
|
counter = num_sampled_;
|
|
|
|
@ -139,24 +145,31 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
|
|
|
|
|
return counter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Per-draw probability of a single class id: the reciprocal of range_max_,
// computed as a float division and then converted to S.
S Probability() {
  const auto inverse_range = 1.0f / range_max_;
  return static_cast<S>(inverse_range);
}
|
|
|
|
|
// Returns the per-draw probability of a single class id, i.e. the reciprocal
// of range_max_, expressed in type S.
//
// Raises through MS_LOG(EXCEPTION) when range_max_ is larger than the greatest
// finite value S can hold.
S Probability() {
  // Compare in a wide floating-point type. The previous check converted
  // std::numeric_limits<S>::max() to int64_t; for a floating-point S that
  // value is far outside the int64_t range, which makes the conversion
  // undefined behaviour and the guard unreliable. Converting both sides to
  // long double is well defined and order-preserving for integer and
  // floating-point S alike.
  const long double range_wide = static_cast<long double>(range_max_);
  const long double s_limit = static_cast<long double>(std::numeric_limits<S>::max());
  if (range_wide > s_limit) {
    MS_LOG(EXCEPTION) << "range_max_ failed to cast";
  }
  const S range = static_cast<S>(range_max_);
  // NOTE(review): range_max_ == 0 is not rejected here, so the division can
  // yield infinity; confirm that attribute validation elsewhere rules it out.
  return static_cast<S>(1.0f / range);
}
|
|
|
|
|
|
|
|
|
|
S ApproximateExpectedCount(S p, int sampled_size, int counter) {
|
|
|
|
|
// Approximates the expected count for a class whose per-draw probability is p.
//   p            - probability of the class in a single draw.
//   sampled_size - first count; Launch passes num_sampled_.
//   counter      - second count; Launch passes the value returned by Sampling().
// When both counts are equal the result is p * sampled_size. Otherwise it is
// 1 - (1 - p)^counter, evaluated as -expm1(counter * log1p(-p)).
S ApproximateExpectedCount(S p, int64_t sampled_size, int64_t counter) {
  const bool counts_differ = (sampled_size != counter);
  if (counts_differ) {
    // log of the probability that the class is missed in every one of the draws
    const auto log_all_missed = counter * std::log1p(-p);
    return -std::expm1(log_all_missed);
  }
  return p * sampled_size;
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int num_true_;
|
|
|
|
|
int num_sampled_;
|
|
|
|
|
int64_t num_true_;
|
|
|
|
|
int64_t num_sampled_;
|
|
|
|
|
bool unique_;
|
|
|
|
|
int range_max_;
|
|
|
|
|
int64_t range_max_;
|
|
|
|
|
size_t input_size_;
|
|
|
|
|
bool remove_accidental_hits_;
|
|
|
|
|
std::vector<T> array_input_;
|
|
|
|
|
std::set<int> set_input_;
|
|
|
|
|
std::set<T> set_input_;
|
|
|
|
|
std::default_random_engine generator_;
|
|
|
|
|
std::vector<int> sampled_candidates_;
|
|
|
|
|
std::vector<T> sampled_candidates_;
|
|
|
|
|
std::vector<size_t> input_size_list_;
|
|
|
|
|
std::vector<size_t> output_size_list_;
|
|
|
|
|
std::vector<size_t> workspace_size_list_;
|
|
|
|
|