!9519 Add dtype supports to op random_categorical on gpu

From: @yuan_shen_zhou
Reviewed-by: 
Signed-off-by:
pull/9519/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ddecc02b0e

@ -17,15 +17,15 @@
#include "backend/kernel_compiler/gpu/cuda_impl/random_categorical.cuh"
template <typename S>
__global__ void RandomCategorical(int num_samples, double** dev_rand, double** dev_cdf,
int batch_size, int num_classes, S *output_addr) {
int size = num_samples * batch_size;
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) {
int cur_row = pos / num_samples;
int cur_col = pos % num_samples;
__global__ void RandomCategorical(const size_t num_samples, double** dev_rand, double** dev_cdf,
const size_t batch_size, const size_t num_classes, S *output_addr) {
size_t size = num_samples * batch_size;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) {
size_t cur_row = pos / num_samples;
size_t cur_col = pos % num_samples;
const double to_find = dev_cdf[cur_row][num_classes-1] * dev_rand[cur_row][cur_col];
int idx = 0;
size_t idx = 0;
while (dev_cdf[cur_row][idx] < to_find) {
idx++;
}
@ -34,22 +34,22 @@ __global__ void RandomCategorical(int num_samples, double** dev_rand, double** d
}
template <typename T>
__global__ void GetCdf(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes) {
int size = num_classes * batch_size;
for (int pos= blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) {
int cur_row = pos / num_classes;
int cur_col = pos % num_classes;
__global__ void GetCdf(const T *logits_addr, double** dev_cdf, const size_t batch_size, const size_t num_classes) {
size_t size = num_classes * batch_size;
for (size_t pos= blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += gridDim.x * blockDim.x) {
size_t cur_row = pos / num_classes;
size_t cur_col = pos % num_classes;
if (cur_col != 0) {
return;
}
T max_of_row = logits_addr[pos];
for (int i = 1; i < num_classes; i++) {
for (size_t i = 1; i < num_classes; i++) {
if (logits_addr[pos + i] > max_of_row) {
max_of_row = logits_addr[pos + i];
}
}
dev_cdf[cur_row][0] = exp(static_cast<double>(logits_addr[pos] - max_of_row));
for (int i = 1; i < num_classes; i++) {
for (size_t i = 1; i < num_classes; i++) {
double tmp = exp(static_cast<double>(logits_addr[pos + i] - max_of_row));
dev_cdf[cur_row][i] = dev_cdf[cur_row][i - 1] + tmp;
}
@ -57,34 +57,34 @@ __global__ void GetCdf(const T *logits_addr, double** dev_cdf, const int batch_s
}
template <typename S>
void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf, int batch_size,
int num_classes, S *output_addr, cudaStream_t cuda_stream) {
int size_out = num_samples * batch_size;
void RandomCategoricalKernel(const size_t num_samples, double** dev_rand, double** dev_cdf,
const size_t batch_size, const size_t num_classes, S *output_addr, cudaStream_t cuda_stream) {
size_t size_out = num_samples * batch_size;
RandomCategorical<<<GET_BLOCKS(size_out), GET_THREADS, 0, cuda_stream>>>(num_samples, dev_rand,
dev_cdf, batch_size,
num_classes, output_addr);
}
template <typename T>
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes,
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const size_t batch_size, const size_t num_classes,
cudaStream_t cuda_stream) {
int size_cdf = num_classes * batch_size;
size_t size_cdf = num_classes * batch_size;
GetCdf<<<GET_BLOCKS(size_cdf), GET_THREADS, 0, cuda_stream>>>(logits_addr, dev_cdf, batch_size, num_classes);
}
template void GetCdfKernel<half>(const half *logits_addr, double** dev_cdf, const int batch_size,
const int num_classes, cudaStream_t cuda_stream);
template void GetCdfKernel<float>(const float *logits_addr, double** dev_cdf, const int batch_size,
const int num_classes, cudaStream_t cuda_stream);
template void GetCdfKernel<double>(const double *logits_addr, double** dev_cdf, const int batch_size,
const int num_classes, cudaStream_t cuda_stream);
template void GetCdfKernel<half>(const half *logits_addr, double** dev_cdf, const size_t batch_size,
const size_t num_classes, cudaStream_t cuda_stream);
template void GetCdfKernel<float>(const float *logits_addr, double** dev_cdf, const size_t batch_size,
const size_t num_classes, cudaStream_t cuda_stream);
template void GetCdfKernel<double>(const double *logits_addr, double** dev_cdf, const size_t batch_size,
const size_t num_classes, cudaStream_t cuda_stream);
template void RandomCategoricalKernel<int16_t>(int num_samples,
double** dev_rand, double** dev_cdf, int batch_size, int num_classes,
template void RandomCategoricalKernel<int16_t>(const size_t num_samples,
double** dev_rand, double** dev_cdf, const size_t batch_size, const size_t num_classes,
int16_t *output_addr, cudaStream_t cuda_stream);
template void RandomCategoricalKernel<int>(int num_samples,
double** dev_rand, double** dev_cdf, int batch_size, int num_classes,
template void RandomCategoricalKernel<int>(const size_t num_samples,
double** dev_rand, double** dev_cdf, const size_t batch_size, const size_t num_classes,
int *output_addr, cudaStream_t cuda_stream);
template void RandomCategoricalKernel<int64_t>(int num_samples,
double** dev_rand, double** dev_cdf, int batch_size, int num_classes,
template void RandomCategoricalKernel<int64_t>(const size_t num_samples,
double** dev_rand, double** dev_cdf, const size_t batch_size, const size_t num_classes,
int64_t *output_addr, cudaStream_t cuda_stream);

@ -19,11 +19,11 @@
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const int batch_size, const int num_classes,
void GetCdfKernel(const T *logits_addr, double** dev_cdf, const size_t batch_size, const size_t num_classes,
cudaStream_t cuda_stream);
template <typename S>
void RandomCategoricalKernel(int num_samples, double** dev_rand, double** dev_cdf,
int batch_size, int num_classes, S *output_addr,
void RandomCategoricalKernel(const size_t num_samples, double** dev_rand, double** dev_cdf,
const size_t batch_size, const size_t num_classes, S *output_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RANDOM_CATEGORICAL_IMPL_H_

@ -18,68 +18,132 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, half, int16_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, half, int32_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, half, int64_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, float, int16_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, float, int32_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, float, int64_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, double, int16_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, double, int32_t)
MS_REG_GPU_KERNEL_TWO(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, double, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, half, int, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, half, int, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, half, int, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, float, int, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, float, int, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, float, int, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, double, int, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, double, int, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, double, int, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, half, int64_t, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, half, int64_t, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, half, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, float, int64_t, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, float, int64_t, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, float, int64_t, int64_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt16),
RandomCategoricalGpuKernel, double, int64_t, int16_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt32),
RandomCategoricalGpuKernel, double, int64_t, int32_t)
MS_REG_GPU_KERNEL_THREE(RandomCategorical,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
RandomCategoricalGpuKernel, double, int64_t, int64_t)
} // namespace kernel
} // namespace mindspore

@ -26,7 +26,7 @@
namespace mindspore {
namespace kernel {
template <typename T, typename S>
template <typename T, typename G, typename S>
class RandomCategoricalGpuKernel : public GpuKernel {
public:
RandomCategoricalGpuKernel() : batch_size_(0), num_classes_(0), num_samples_(0), seed_(0) {}
@ -43,7 +43,7 @@ class RandomCategoricalGpuKernel : public GpuKernel {
std::unique_ptr<double *[]> host_cdf;
host_cdf = std::make_unique<double *[]>(batch_size_);
for (int i = 0; i < batch_size_; i++) {
for (size_t i = 0; i < batch_size_; i++) {
host_cdf[i] = GetDeviceAddress<double>(workspaces, i);
}
double **dev_cdf = GetDeviceAddress<double *>(workspaces, batch_size_);
@ -55,18 +55,18 @@ class RandomCategoricalGpuKernel : public GpuKernel {
std::unique_ptr<double *[]> host_rand;
host_rand = std::make_unique<double *[]>(batch_size_);
for (int i = 0; i < batch_size_; i++) {
for (size_t i = 0; i < batch_size_; i++) {
host_rand[i] = GetDeviceAddress<double>(workspaces, batch_size_ + 1 + i);
}
double **dev_rand = GetDeviceAddress<double *>(workspaces, batch_size_ * 2 + 1);
for (int i = 0; i < batch_size_; i++) {
for (size_t i = 0; i < batch_size_; i++) {
std::unique_ptr<double[]> host_1d_rand;
host_1d_rand = std::make_unique<double[]>(num_samples_);
std::default_random_engine rng(seed_);
std::default_random_engine rng(static_cast<G>(seed_));
std::uniform_real_distribution<> dist(0, 1);
for (int j = 0; j < num_samples_; j++) {
for (size_t j = 0; j < num_samples_; j++) {
host_1d_rand[j] = dist(rng);
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
@ -105,11 +105,11 @@ class RandomCategoricalGpuKernel : public GpuKernel {
MS_LOG(ERROR) << "logits's dims is " << logits_shape.size() << ", but it should be only 2-D.";
return false;
}
batch_size_ = SizeToInt(logits_shape[0]);
num_classes_ = SizeToInt(logits_shape[1]);
batch_size_ = logits_shape[0];
num_classes_ = logits_shape[1];
num_samples_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "num_samples"));
seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
num_samples_ = LongToSize(GetAttr<int64_t>(kernel_node, "num_samples"));
seed_ = GetAttr<int64_t>(kernel_node, "seed");
InitSizeLists();
return true;
@ -120,26 +120,26 @@ class RandomCategoricalGpuKernel : public GpuKernel {
void InitSizeLists() override {
// init memory
input_size_list_.push_back(sizeof(T) * batch_size_ * num_classes_);
input_size_list_.push_back(sizeof(int) * 2);
input_size_list_.push_back(sizeof(G));
input_size_list_.push_back(sizeof(G));
output_size_list_.push_back(sizeof(S) * batch_size_ * num_samples_);
for (int i = 0; i < batch_size_; i++) {
for (size_t i = 0; i < batch_size_; i++) {
workspace_size_list_.push_back(sizeof(double) * num_classes_);
}
workspace_size_list_.push_back(sizeof(double *) * batch_size_);
for (int i = 0; i < batch_size_; i++) {
for (size_t i = 0; i < batch_size_; i++) {
workspace_size_list_.push_back(sizeof(double) * num_samples_);
}
workspace_size_list_.push_back(sizeof(double *) * batch_size_);
}
private:
int batch_size_;
int num_classes_;
int num_samples_;
int seed_;
size_t batch_size_;
size_t num_classes_;
size_t num_samples_;
int64_t seed_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

Loading…
Cancel
Save