|
|
|
@ -18,6 +18,8 @@
|
|
|
|
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANDOM_RANDOM_CHOICE_WITH_MASK_GPU_KERNEL_H_
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <chrono>
|
|
|
|
|
#include <random>
|
|
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh"
|
|
|
|
@ -27,7 +29,8 @@ namespace kernel {
|
|
|
|
|
template <typename T, typename S>
|
|
|
|
|
class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|
|
|
|
public:
|
|
|
|
|
RandomChoiceWithMaskGpuKernel() : input_shape_size_(0), seedc_(0), input_size_(1), count_(0), ceil_power2_(0) {}
|
|
|
|
|
RandomChoiceWithMaskGpuKernel()
|
|
|
|
|
: input_shape_size_(0), seed_(0), seed2_(0), input_size_(1), count_(0), ceil_power2_(0) {}
|
|
|
|
|
~RandomChoiceWithMaskGpuKernel() override = default;
|
|
|
|
|
|
|
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
|
|
|
@ -39,6 +42,14 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|
|
|
|
T *input = GetDeviceAddress<T>(inputs, 0);
|
|
|
|
|
S *output_index = GetDeviceAddress<S>(outputs, 0);
|
|
|
|
|
T *output_mask = GetDeviceAddress<T>(outputs, 1);
|
|
|
|
|
int seedc = 0;
|
|
|
|
|
if (seed2_ != 0) {
|
|
|
|
|
seedc = seed2_;
|
|
|
|
|
} else if (seed_ != 0) {
|
|
|
|
|
seedc = seed_;
|
|
|
|
|
} else {
|
|
|
|
|
seedc = generator_();
|
|
|
|
|
}
|
|
|
|
|
if (count_ > kSmallK || input_shape_size_ > 1) {
|
|
|
|
|
S *index_buff = GetDeviceAddress<S>(workspaces, 0);
|
|
|
|
|
S *mask_buff = GetDeviceAddress<S>(workspaces, 1);
|
|
|
|
@ -48,17 +59,18 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|
|
|
|
void *States = GetDeviceAddress<void *>(workspaces, 5);
|
|
|
|
|
curandState *devStates = reinterpret_cast<curandState *>(States);
|
|
|
|
|
CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1],
|
|
|
|
|
input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input,
|
|
|
|
|
input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc, count_, input,
|
|
|
|
|
output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff,
|
|
|
|
|
devStates, reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
} else {
|
|
|
|
|
CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc_, count_, input, output_index, output_mask,
|
|
|
|
|
CalRandomChoiceWithMaskSmall<float, S, T>(input_size_, seedc, count_, input, output_index, output_mask,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override {
|
|
|
|
|
uint32_t time_interval = std::chrono::system_clock::now().time_since_epoch().count();
|
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
|
if (input_num != 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but RandomChoiceWithMask needs 1 input.";
|
|
|
|
@ -84,15 +96,10 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|
|
|
|
while (input_shape_5D_.size() != MAX_DIMENSION) {
|
|
|
|
|
input_shape_5D_.insert(input_shape_5D_.begin(), 1);
|
|
|
|
|
}
|
|
|
|
|
// init seedc_
|
|
|
|
|
int seed = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
|
|
|
|
|
int seed2 = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2"));
|
|
|
|
|
if (seed2 != 0)
|
|
|
|
|
seedc_ = seed2;
|
|
|
|
|
else if (seed != 0)
|
|
|
|
|
seedc_ = seed;
|
|
|
|
|
else
|
|
|
|
|
seedc_ = time(NULL);
|
|
|
|
|
// init seedc
|
|
|
|
|
seed_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed"));
|
|
|
|
|
seed2_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "seed2"));
|
|
|
|
|
generator_.seed(time_interval);
|
|
|
|
|
// init memory
|
|
|
|
|
for (size_t i = 0; i < input_shape.size(); i++) {
|
|
|
|
|
input_size_ *= input_shape[i];
|
|
|
|
@ -125,10 +132,12 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel {
|
|
|
|
|
private:
|
|
|
|
|
const int kSmallK = 2048;
|
|
|
|
|
int input_shape_size_;
|
|
|
|
|
int seedc_;
|
|
|
|
|
int seed_;
|
|
|
|
|
int seed2_;
|
|
|
|
|
int input_size_;
|
|
|
|
|
int count_;
|
|
|
|
|
int ceil_power2_;
|
|
|
|
|
std::mt19937 generator_;
|
|
|
|
|
std::vector<int> input_shape_5D_;
|
|
|
|
|
std::vector<size_t> input_size_list_;
|
|
|
|
|
std::vector<size_t> output_size_list_;
|
|
|
|
|