!9750 add int64 support to UniformCandidateSampler GPU

From: @TFbunny
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
pull/9750/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b401a5fb07

@ -17,20 +17,20 @@
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh"
template <typename S>
__global__ void AssignToOutput(const int size, const S prob_val, S *output_array) {
__global__ void AssignToOutput(const int64_t size, const S prob_val, S *output_array) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output_array[pos] = prob_val;
}
}
template <typename S>
void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream) {
void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val,
S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream) {
AssignToOutput<<<GET_BLOCKS(true_size), GET_THREADS, 0, cuda_stream>>>(true_size, prob_val, true_expected_count);
AssignToOutput<<<GET_BLOCKS(num_sampled), GET_THREADS, 0, cuda_stream>>>(num_sampled, prob_val,
sampled_expected_count);
}
template void CalUniformCandidateSampler<float>(const int true_size, const int num_sampled, const float prob_val,
float *true_expected_count, float *sampled_expected_count,
cudaStream_t cuda_stream);
template void CalUniformCandidateSampler<float>(const int64_t true_size, const int64_t num_sampled,
const float prob_val, float *true_expected_count,
float *sampled_expected_count, cudaStream_t cuda_stream);

@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename S>
void CalUniformCandidateSampler(const int true_size, const int num_sampled, const S prob_val, S *true_expected_count,
S *sampled_expected_count, cudaStream_t cuda_stream);
void CalUniformCandidateSampler(const int64_t true_size, const int64_t num_sampled, const S prob_val,
S *true_expected_count, S *sampled_expected_count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_UNIFORM_CANDIDATE_SAMPLER_IMPL_CUH_

@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler,
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
UniformCandidateSamplerGpuKernel, int, float)
MS_REG_GPU_KERNEL_TWO(UniformCandidateSampler,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
UniformCandidateSamplerGpuKernel, int64_t, float)
} // namespace kernel
} // namespace mindspore

@ -21,6 +21,7 @@
#include <set>
#include <vector>
#include <random>
#include <limits>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/uniform_candidate_sampler_impl.cuh"
@ -55,15 +56,15 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
set_input_.insert(item);
}
}
int counter = Sampling();
float prob = Probability();
int64_t counter = Sampling();
S prob = Probability();
size_t sampled_candidates_size = num_sampled_ * sizeof(T);
S value = ApproximateExpectedCount(prob, num_sampled_, counter);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(sampled_candidates, &sampled_candidates_[0], sampled_candidates_size,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync sampled_candidates failed");
CalUniformCandidateSampler(static_cast<int>(input_size_), num_sampled_, value, true_expected_count,
CalUniformCandidateSampler(static_cast<int64_t>(input_size_), num_sampled_, value, true_expected_count,
sampled_expected_count, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -81,11 +82,11 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
return false;
}
// getting attrs
num_true_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_true"));
num_sampled_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_sampled"));
num_true_ = GetAttr<int64_t>(kernel_node, "num_true");
num_sampled_ = GetAttr<int64_t>(kernel_node, "num_sampled");
unique_ = GetAttr<bool>(kernel_node, "unique");
range_max_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "range_max"));
int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
range_max_ = GetAttr<int64_t>(kernel_node, "range_max");
int64_t seed = GetAttr<int64_t>(kernel_node, "seed");
remove_accidental_hits_ = GetAttr<bool>(kernel_node, "remove_accidental_hits");
if (seed == 0) seed = time(NULL);
generator_.seed(seed);
@ -95,7 +96,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
return false;
}
input_size_ = input_shape[0] * input_shape[1];
if (num_sampled_ * num_true_ + static_cast<int>(input_size_) > range_max_ * num_true_) {
if (num_sampled_ * num_true_ + static_cast<int64_t>(input_size_) > range_max_ * num_true_) {
remove_accidental_hits_ = false;
}
InitSizeLists();
@ -110,13 +111,18 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
output_size_list_.push_back(num_sampled_ * sizeof(S));
}
int Sampling() {
int counter = 0;
int tmp;
int picked;
std::set<int> set_container;
int64_t Sampling() {
int64_t counter = 0;
T tmp;
int64_t picked;
std::set<T> set_container;
// pick between [0, range_max_-1]
std::uniform_int_distribution<int> distribution(0, range_max_ - 1);
T range;
if (range_max_ > static_cast<int64_t>(std::numeric_limits<T>::max())) {
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
}
range = static_cast<T>(range_max_);
std::uniform_int_distribution<T> distribution(0, range - 1);
sampled_candidates_.clear();
if (unique_) {
picked = 0;
@ -131,7 +137,7 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
}
}
} else {
for (int i = 0; i < num_sampled_; i++) {
for (int64_t i = 0; i < num_sampled_; i++) {
sampled_candidates_.push_back(distribution(generator_));
}
counter = num_sampled_;
@ -139,24 +145,31 @@ class UniformCandidateSamplerGpuKernel : public GpuKernel {
return counter;
}
S Probability() { return static_cast<S>(1.0f / range_max_); }
S Probability() {
S range;
if (range_max_ > static_cast<int64_t>(std::numeric_limits<S>::max())) {
MS_LOG(EXCEPTION) << "range_max_ failed to cast";
}
range = static_cast<S>(range_max_);
return static_cast<S>(1.0f / range);
}
S ApproximateExpectedCount(S p, int sampled_size, int counter) {
S ApproximateExpectedCount(S p, int64_t sampled_size, int64_t counter) {
if (sampled_size == counter) return p * sampled_size;
return -std::expm1(counter * std::log1p(-p));
}
private:
int num_true_;
int num_sampled_;
int64_t num_true_;
int64_t num_sampled_;
bool unique_;
int range_max_;
int64_t range_max_;
size_t input_size_;
bool remove_accidental_hits_;
std::vector<T> array_input_;
std::set<int> set_input_;
std::set<T> set_input_;
std::default_random_engine generator_;
std::vector<int> sampled_candidates_;
std::vector<T> sampled_candidates_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

@ -583,7 +583,9 @@ class UniformCandidateSampler(PrimitiveWithInfer):
self.num_sampled = num_sampled
def infer_dtype(self, true_classes_type):
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int32), self.name)
Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type,
(mstype.int32, mstype.int64), self.name)
return (true_classes_type, mstype.float32, mstype.float32)
def infer_shape(self, true_classes_shape):

@ -39,6 +39,14 @@ def uniform_candidate_sampler(x, num_true, num_sampled, unique, range_max):
out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int32)))
return out1.shape, out2.shape, out3.shape
def uniform_candidate_sampler_int64(x, num_true, num_sampled, unique, range_max):
uniform_candidate_sampler_net = UniformCandidateSamplerNet(num_true,
num_sampled,
unique,
range_max)
out1, out2, out3 = uniform_candidate_sampler_net(Tensor(x.astype(np.int64)))
return out1.shape, out2.shape, out3.shape
class UniformCandidateSamplerHitNet(nn.Cell):
def __init__(self, num_true, num_sampled, unique, range_max, seed, remove_accidental_hits):
@ -155,6 +163,19 @@ def test_uniform_candidate_sampler_large_random():
np.testing.assert_array_equal(ms2, expected_2)
np.testing.assert_array_equal(ms3, expected_3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_uniform_candidate_sampler_large_random_int64_input():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
ms1, ms2, ms3 = uniform_candidate_sampler_int64(np.arange(2142).reshape(34, 63),
63, 10, False, 12)
expected_1 = (10,)
expected_2 = (34, 63)
expected_3 = (10,)
np.testing.assert_array_equal(ms1, expected_1)
np.testing.assert_array_equal(ms2, expected_2)
np.testing.assert_array_equal(ms3, expected_3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training

Loading…
Cancel
Save