diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h index 31210fb707..3c23dc828e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,10 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TOPK_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ +#include #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" @@ -27,7 +28,7 @@ namespace kernel { template class TopKGpuKernel : public GpuKernel { public: - TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {} + TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), input_shape_size_(0) {} ~TopKGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } @@ -40,26 +41,17 @@ class TopKGpuKernel : public GpuKernel { S *k = GetDeviceAddress(inputs, 1); T *output_addr = GetDeviceAddress(outputs, 0); S *indices = GetDeviceAddress(outputs, 1); - T *data_buff = nullptr; - S *index_buff = nullptr; - if (use_share_mem_ == false) { - data_buff = GetDeviceAddress(workspaces, 0); - index_buff = GetDeviceAddress(workspaces, 1); - } - - TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff, - reinterpret_cast(stream_ptr)); + const T init_k = std::numeric_limits::lowest(); - if (sorted_ == false) { - BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff, - reinterpret_cast(stream_ptr)); - } + FastTopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, init_k, + reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); + input_shape_size_ = input_shapes.size(); for (size_t i = 0; i < input_shapes.size() - 1; i++) { outer_size_ *= input_shapes[i]; } @@ -68,13 +60,6 @@ class TopKGpuKernel : public GpuKernel { sorted_ = GetAttr(kernel_node, "sorted"); - ceil_power2_ = RoundUpPower2(inner_size_); - size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S)); - if (buffer_size > SHARED_MEM_PER_BLOCK) { - use_share_mem_ = false; - MS_LOG(INFO) << "CUDA share memory not enough, sort with RAM"; - } - InitSizeLists(); return true; } @@ -85,10 +70,6 @@ class TopKGpuKernel : public GpuKernel { input_size_list_.push_back(sizeof(S)); output_size_list_.push_back(outer_size_ * k_ * sizeof(T)); output_size_list_.push_back(outer_size_ * k_ * sizeof(S)); - if (use_share_mem_ == false) { - workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T)); - workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S)); - } } private: @@ -96,8 +77,7 @@ class TopKGpuKernel : public GpuKernel { size_t outer_size_; size_t inner_size_; size_t k_; - bool 
use_share_mem_; - size_t ceil_power2_; + int input_shape_size_; std::vector input_size_list_; std::vector output_size_list_; @@ -106,4 +86,4 @@ class TopKGpuKernel : public GpuKernel { } // namespace kernel } // namespace mindspore -#endif // TopKpuKernel +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TOPK_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh index bb654e4b58..c3dd10754f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,10 @@ #define BLOCKSIZE 256 #define MAX_DIMENSION 5 +template +void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask, + cudaStream_t stream); + template void CalRandomChoiceWithMask(const int &input_size, const int &input_shape_size, const int &d1, const int &d2, const int &d3, const int &d4, const int &d5, const int &seedc, const int &count, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rcwm_small_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rcwm_small_impl.cu new file mode 100644 index 0000000000..e8503a0bcb --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/rcwm_small_impl.cu @@ -0,0 +1,152 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/random_choice_with_mask_impl.cuh" + +// Kernel started from here +#define L2_RCWM_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \ + do { \ + L2Rcwm \ + <<<1, BLOCK, 0, stream>>>(seedc, input_size, input, output_mask, output_index, k); \ + } while (0) + +#define LEFT_INSERT_THREAD_QUEUE(_k, _v) \ + do { \ + if (is_descend ? 
Cmp::gt(_k, warp_K_top) : Cmp::lt(_k, warp_K_top)) { \ + { \ + _Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \ + threadK[i] = threadK[i - 1]; \ + threadV[i] = threadV[i - 1]; \ + } \ + } \ + threadK[0] = _k; \ + threadV[0] = _v; \ + ++num_vals; \ + } \ + } while (0) + +template +__global__ void L2Rcwm(int seedc, int input_size, const K *input, K *output_mask, S *output_index, int k) { + constexpr int kNumWarps = threads_per_block / kWarpSize; + constexpr T init_K = static_cast(-2.0); + constexpr S init_V = static_cast(0); + + __shared__ T shared_K[kNumWarps * warp_queue]; + __shared__ S shared_V[kNumWarps * warp_queue]; + + curandState devState; + curand_init(seedc, threadIdx.x, 0, &devState); + + T threadK[thread_queue]; // NOLINT + S threadV[thread_queue]; // NOLINT + + T *warp_K; + S *warp_V; + + T warp_K_top = init_K; + int k_minus_1 = k - 1; + int num_vals = 0; + int limit = (input_size / kWarpSize) * kWarpSize; + int i = threadIdx.x; + + // init begin + _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { + threadK[i] = init_K; + threadV[i] = init_V; + } + + int laneId = GetLaneId(); + int warpId = threadIdx.x / kWarpSize; // 0,1,2 or 3 + + // warp shared memory start address + warp_K = shared_K + warpId * warp_queue; + warp_V = shared_V + warpId * warp_queue; + + for (int i = laneId; i < warp_queue; i += kWarpSize) { + warp_K[i] = init_K; + warp_V[i] = init_V; + } + + // sync till all threads init done + __syncwarp(); + + // insert begin + for (; i < limit; i += threads_per_block) { + T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K; + LEFT_INSERT_THREAD_QUEUE(rand_num, i); + + // CHECK_AND_MERGE_THREAD_QUEUE() begin + bool needSort = (num_vals == thread_queue); + needSort = __any_sync(0xffffffff, needSort); + if (!needSort) continue; + + MergeWarpQueue(threadK, threadV, warp_K, warp_V); + + num_vals = 0; + _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { + threadK[i] = init_K; + threadV[i] = init_V; + } + warp_K_top = warp_K[k_minus_1]; + __syncwarp(); + } + + if (i < input_size) { + T rand_num = input[i] ? __uint2float_rn(curand(&devState)) : init_K; + LEFT_INSERT_THREAD_QUEUE(rand_num, i); + } + + // reduce begin + MergeWarpQueue(threadK, threadV, warp_K, warp_V); + __syncthreads(); + SortBlockWide(shared_K, shared_V); + + // ship data from shared memory to output buffer + for (int i = threadIdx.x; i < k; i += blockDim.x) { + output_mask[i] = shared_K[i] > static_cast(-1.0) ? 
true : false; + output_index[i] = shared_V[i]; + } +} + +template +void RCWMScaleK(int seedc, int input_size, K *input, int k, S *output_index, K *output_mask, cudaStream_t stream) { + if (k <= 32) { + // num-threads-of-block, warp-queue-size, thread-queue-size + L2_RCWM_HELPER(256, 32, 2, true); + } else if (k <= 64) { + L2_RCWM_HELPER(256, 64, 3, true); + } else if (k <= 128) { + L2_RCWM_HELPER(256, 128, 3, true); + } else if (k <= 256) { + L2_RCWM_HELPER(256, 256, 4, true); + } else if (k <= 512) { + L2_RCWM_HELPER(256, 512, 8, true); + } else if (k <= 1024) { + L2_RCWM_HELPER(128, 1024, 8, true); + } else if (k <= 2048) { + L2_RCWM_HELPER(64, 2048, 8, true); + } +} + +template +void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, K *input, S *output_index, K *output_mask, + cudaStream_t stream) { + RCWMScaleK(seedc, input_size, input, count, output_index, output_mask, stream); +} + +template void CalRandomChoiceWithMaskSmall(int input_size, int seedc, int count, bool *input, + int *output_index, bool *output_mask, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu index e20a40a276..6976c46dd2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,148 +15,213 @@ */ #include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh" +#include "backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh" #include #include -size_t RoundUpPower2(size_t v) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v++; - return v; -} +const int kMaxQueue = 128; -template -__inline__ __device__ void Swap(T *lhs, T *rhs) { - T tmp = lhs[0]; - lhs[0] = rhs[0]; - rhs[0] = tmp; -} +#define TOPK_HELPER(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, IS_DESCEND) \ + do { \ + TopKBlock \ + <<>>(outer_size, inner_size, input, output, output_index, k_cut, init_K); \ + } while (0) -template -__global__ void TopkKernel(const size_t outer, const size_t inner, const size_t ceil_power2, const T *input, const S *k, - T *output, S *indices, T *data_buff, S *index_buff) { - // default: sort with share memory - extern __shared__ T share_mem[]; - T *data_arr = share_mem; - S *index_arr = reinterpret_cast(data_arr + ceil_power2); - // sort with RAM - if (data_buff != nullptr && index_buff != nullptr) { - data_arr = data_buff + blockIdx.x * ceil_power2; - index_arr = index_buff + blockIdx.x * ceil_power2; +#define LEFT_INSERT_THREAD_QUEUE(_k, _v) \ + do { \ + if (is_descend ? CmpKV::gt(_k, _v, (*ceil_K), (*ceil_V)) : CmpKV::lt(_k, _v, (*ceil_K), (*ceil_V))) \ + break; \ + if (is_descend ? 
CmpKV::gt(_k, _v, warp_K_top, warp_V_top) \ + : CmpKV::lt(_k, _v, warp_K_top, warp_V_top)) { \ + { \ + _Pragma("unroll") for (int i = thread_queue - 1; i > 0; --i) { \ + threadK[i] = threadK[i - 1]; \ + threadV[i] = threadV[i - 1]; \ + } \ + } \ + threadK[0] = _k; \ + threadV[0] = _v; \ + ++num_vals; \ + } \ + } while (0) + +template +inline __device__ void TopKInBuffer(T *shared_K, S *shared_V, int *watermark, T *ceil_K, S *ceil_V, int laneId) { + constexpr int kNumWarps = threads_per_block / kWarpSize; // kNumWarps is 1024/32=32 + + // find last_K, which is max of last element of warp queue + T last_K = shared_K[laneId * warp_queue + warp_queue - 1]; + S last_V = shared_V[laneId * warp_queue + warp_queue - 1]; + + __syncwarp(); + + for (int offset = kNumWarps / 2; offset > 0; offset /= 2) { + // kNumWarps is 32 if block size is 1024 + T other_K = __shfl_down_sync(0xffffffff, last_K, offset); + S other_V = __shfl_down_sync(0xffffffff, last_V, offset); + + bool is_greater = CmpKV::gt(other_K, other_V, last_K, last_V); + ConditionalAssign(is_greater, &last_K, other_K); + ConditionalAssign(is_greater, &last_V, other_V); } + __syncwarp(); - for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) { - data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits::max(); - index_arr[i] = i; + if (laneId == 0) { + *ceil_K = last_K; + *ceil_V = last_V; } - __syncthreads(); + __syncwarp(); - for (size_t i = 2; i <= ceil_power2; i <<= 1) { - for (size_t j = (i >> 1); j > 0; j >>= 1) { - for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { - size_t tid_comp = tid ^ j; - if (tid_comp > tid) { - if ((tid & i) == 0) { - if (data_arr[tid] > data_arr[tid_comp]) { - Swap(&data_arr[tid], &data_arr[tid_comp]); - Swap(&index_arr[tid], &index_arr[tid_comp]); - } - } else { - if (data_arr[tid] < data_arr[tid_comp]) { - Swap(&data_arr[tid], &data_arr[tid_comp]); - Swap(&index_arr[tid], &index_arr[tid_comp]); - } - } - } - } - __syncthreads(); - } + // calculate index cut by last_K + int L = 0; + int R = warp_queue; + while (L < R) { + int m = (L + R) / 2; + CmpKV::gt(shared_K[laneId * warp_queue + m], shared_V[laneId * warp_queue + m], (*ceil_K), (*ceil_V)) + ? L = m + 1 + : R = m; } + __syncwarp(); - for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) { - output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1]; - indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1]; + // merge top number which value is greater than last_K + for (int offset = kNumWarps / 2; offset > 0; offset /= 2) { + R += __shfl_down_sync(0xffffffff, R, offset); } -} -template -void TopK(const size_t &outer, const size_t &inner, const T *input, const S *k, T *output, S *indices, T *data_buff, - S *index_buff, cudaStream_t stream) { - size_t ceil_power2 = RoundUpPower2(inner); - size_t share_mem = (data_buff == nullptr) ? 
ceil_power2 * (sizeof(T) + sizeof(S)) : 0; - size_t thread_num = std::min(ceil_power2, static_cast(GET_THREADS)); - TopkKernel<<>>(outer, inner, ceil_power2, input, k, output, indices, data_buff, - index_buff); + __syncwarp(); + + if (laneId == 0) { + watermark[0] = R; + } + __syncwarp(); } -template -__global__ void BitonicSortByKeyKernel(const size_t outer, const size_t inner, const size_t ceil_power2, T *input, - S *indices, T *data_buff, S *index_buff) { - // default: sort with share memory - extern __shared__ T share_mem[]; - T *data_arr = share_mem; - S *index_arr = reinterpret_cast(data_arr + ceil_power2); - // sort with RAM - if (data_buff != nullptr && index_buff != nullptr) { - data_arr = data_buff + blockIdx.x * ceil_power2; - index_arr = index_buff + blockIdx.x * ceil_power2; +template +inline __device__ void TopKStep(const int &outer_size, const int &inner_size, const T *input, T *output, + S *output_index, S k_cut, const T &init_K, const int &outer_id, T *shared_K, + S *shared_V, int *watermark, T *threadK, S *threadV, T *ceil_K, S *ceil_V, S *k_prime) { + constexpr int kNumWarps = threads_per_block / kWarpSize; + constexpr S init_V = static_cast(-1); + + T *warp_K; + S *warp_V; + + T warp_K_top = init_K; + S warp_V_top = init_V; + int k_minus_1 = (k_cut <= kMaxQueue ? k_cut - 1 : kMaxQueue - 1); + int num_vals = 0; + int limit = (inner_size / kWarpSize) * kWarpSize; + + _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { + threadK[i] = init_K; + threadV[i] = init_V; } - for (size_t i = threadIdx.x; i < ceil_power2; i += blockDim.x) { - data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits::max(); - index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits::max(); + int laneId = GetLaneId(); + int warpId = threadIdx.x / kWarpSize; // 0,1,2 or 3 + + warp_K = shared_K + warpId * warp_queue; + warp_V = shared_V + warpId * warp_queue; + + for (int i = laneId; i < warp_queue; i += kWarpSize) { + warp_K[i] = init_K; + warp_V[i] = init_V; } - __syncthreads(); - for (size_t i = 2; i <= ceil_power2; i <<= 1) { - for (size_t j = (i >> 1); j > 0; j >>= 1) { - for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) { - size_t tid_comp = tid ^ j; - if (tid_comp > tid) { - if ((tid & i) == 0) { - if (index_arr[tid] > index_arr[tid_comp]) { - Swap(&data_arr[tid], &data_arr[tid_comp]); - Swap(&index_arr[tid], &index_arr[tid_comp]); - } - } else { - if (index_arr[tid] < index_arr[tid_comp]) { - Swap(&data_arr[tid], &data_arr[tid_comp]); - Swap(&index_arr[tid], &index_arr[tid_comp]); - } - } - } - } - __syncthreads(); + __syncwarp(); + + int i = threadIdx.x; + for (; i < limit; i += threads_per_block) { + LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i)); + + bool needSort = (num_vals == thread_queue); + needSort = __any_sync(0xffffffff, needSort); + if (!needSort) continue; + + MergeWarpQueue(threadK, threadV, warp_K, warp_V); + + num_vals = 0; + _Pragma("unroll") for (int i = 0; i < thread_queue; ++i) { + threadK[i] = init_K; + threadV[i] = init_V; } + warp_K_top = warp_K[k_minus_1]; + warp_V_top = warp_V[k_minus_1]; + __syncwarp(); + } + + if (i < inner_size) { + LEFT_INSERT_THREAD_QUEUE((input[outer_id * inner_size + i]), (outer_id * inner_size + i)); + } + + MergeWarpQueue(threadK, threadV, warp_K, warp_V); + __syncthreads(); + + if (k_cut > kMaxQueue && warpId == 0) { + TopKInBuffer(shared_K, shared_V, watermark, ceil_K, + ceil_V, laneId); } + __syncthreads(); + + 
SortBlockWide(shared_K, shared_V); + + S k_step = (*k_prime) + watermark[0] <= k_cut ? watermark[0] : k_cut - (*k_prime); + for (int i = threadIdx.x; i < k_step; i += blockDim.x) { + output[outer_id * k_cut + (*k_prime) + i] = shared_K[i]; + output_index[outer_id * k_cut + (*k_prime) + i] = shared_V[i] % inner_size; + } + *k_prime += k_step; + __syncthreads(); +} + +template +__global__ void TopKBlock(int outer_size, int inner_size, const T *input, T *output, S *output_index, S k_cut, + const T init_K) { + constexpr int kNumWarps = threads_per_block / kWarpSize; + + __shared__ T shared_K[kNumWarps * warp_queue]; + __shared__ S shared_V[kNumWarps * warp_queue]; + __shared__ int watermark[1]; + __shared__ T ceil_K; + __shared__ S ceil_V; - for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) { - input[blockIdx.x * inner + tid] = data_arr[tid]; - indices[blockIdx.x * inner + tid] = index_arr[tid]; + T threadK[thread_queue]; // NOLINT + S threadV[thread_queue]; // NOLINT + + for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < blockDim.x * outer_size; + t_idx += blockDim.x * gridDim.x) { + S k_prime = 0; + int outer_id = t_idx / blockDim.x; + ceil_K = -init_K; + ceil_V = -1; + watermark[0] = k_cut; + do { + TopKStep( + outer_size, inner_size, input, output, output_index, k_cut, init_K, outer_id, shared_K, shared_V, watermark, + threadK, threadV, &ceil_K, &ceil_V, &k_prime); + } while (k_prime < k_cut); } } template -void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff, - cudaStream_t stream) { - size_t ceil_power2 = RoundUpPower2(inner); - size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S)); - if (share_mem > SHARED_MEM_PER_BLOCK) { - share_mem = 0; +void FastTopK(const int outer_size, const int inner_size, const T *input, const S *k, T *output, S *output_index, + const T init_K, cudaStream_t stream) { + int block_num_limit = outer_size < 128 ? 
outer_size : 128; + S k_cut = 0; + cudaMemcpy(&k_cut, k, sizeof(S), cudaMemcpyDeviceToHost); + if (k_cut > inner_size) k_cut = inner_size; + + if (k_cut <= 32) { + // num-threads-of-block, warp-queue-size, thread-queue-size + TOPK_HELPER(256, 32, 2, true); + } else if (k_cut <= 64) { + TOPK_HELPER(256, 64, 3, true); + } else if (k_cut <= 128) { + TOPK_HELPER(256, 128, 3, true); } else { - data_buff = nullptr; - index_buff = nullptr; + TOPK_HELPER(1024, 128, 3, true); } - size_t thread_num = std::min(ceil_power2, static_cast(GET_THREADS)); - BitonicSortByKeyKernel<<>>(outer, inner, ceil_power2, input, indices, data_buff, - index_buff); } -template void TopK(const size_t &outer, const size_t &inner, const float *input_addr, const int *k, float *output, - int *indices, float *data_buff, int *index_buff, cudaStream_t stream); -template void BitonicSortByKey(const size_t &outer, const size_t &inner, float *input, int *indices, float *data_buff, - int *index_buff, cudaStream_t stream); +template void FastTopK(const int outer_size, const int inner_size, const float *input, const int *k, float *output, + int *output_index, const float init_K, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh index a972995714..7ecb392c1b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,19 +14,14 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ #include #include "runtime/device/gpu/cuda_common.h" template -void TopK(const size_t &outer, const size_t &inner, const T *input_addr, const S *k, T *output, S *indices, - T *data_buff, S *index_buff, cudaStream_t stream); +void FastTopK(const int outer, const int inner, const T *input_addr, const S *k, T *output, S *indices, const T initK, + cudaStream_t stream); -template -void BitonicSortByKey(const size_t &outer, const size_t &inner, T *input, S *indices, T *data_buff, S *index_buff, - cudaStream_t stream); -size_t RoundUpPower2(size_t v); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_ +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TOPK_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh new file mode 100644 index 0000000000..a960b6de6d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_lib.cuh @@ -0,0 +1,479 @@ +/** + * Copyright 2020-2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +constexpr int kWarpSize = 32; + +constexpr __host__ __device__ int Log2(int n, int p = 0) { return (n <= 1) ? p : Log2(n / 2, p + 1); } +constexpr __host__ __device__ bool IsPow2(int v) { return (v && !(v & (v - 1))); } +constexpr __host__ __device__ int NextPow2(int v) { return (IsPow2(v) ? 2 * v : (1 << static_cast(Log2(v) + 1))); } + +__device__ __forceinline__ int GetLaneId() { + int laneId; + asm("mov.u32 %0, %%laneid;" : "=r"(laneId)); + return laneId; +} + +template +struct CmpKV { + __device__ static inline bool gt(T k1, S v1, T k2, S v2) { return k1 > k2 || (k1 == k2 && v1 < v2); } + __device__ static inline bool lt(T k1, S v1, T k2, S v2) { return k1 < k2 || (k1 == k2 && v1 > v2); } +}; + +template +struct Cmp { + __device__ static inline bool lt(T a, T b) { return a < b; } + __device__ static inline bool gt(T a, T b) { return a > b; } +}; + +template +inline __device__ T shfl_xor(const T val, int laneMask, int width = kWarpSize) { + return __shfl_xor_sync(0xffffffff, val, laneMask, width); +} + +template +inline __device__ void L2CompareAndSwap(T *a, S *b, int i_1, int i_2) { + bool swap = + is_descend ? CmpKV::gt(a[i_1], b[i_1], a[i_2], b[i_2]) : CmpKV::lt(a[i_1], b[i_1], a[i_2], b[i_2]); + + if (!swap) return; + + T a_tmp = a[i_1]; + a[i_1] = a[i_2]; + a[i_2] = a_tmp; + + T b_tmp = b[i_1]; + b[i_1] = b[i_2]; + b[i_2] = b_tmp; +} + +template +inline __device__ void ConditionalAssign(bool is_assign, T *x, const T &y) { + (*x) = is_assign ? y : (*x); +} + +// Merge pairs of lists smaller than threads-per-block +// NumThreads is 128 +// N is 2, 1 etc +// L is 32, 64 etc +template +inline __device__ void BlockSortSmallK(T *list_k, S *list_v) { + int mergeId = threadIdx.x / L; + int tid = threadIdx.x % L; + + list_k += 2 * L * mergeId; + list_v += 2 * L * mergeId; + + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + if (AllThreads || (static_cast(threadIdx.x) < N * L)) { + L2CompareAndSwap(list_k, list_v, pos, pos + stride); + } + + __syncthreads(); + + _Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) { + int pos = 2 * tid - (tid & (stride - 1)); + + if (AllThreads || (static_cast(threadIdx.x) < N * L)) { + L2CompareAndSwap(list_k, list_v, pos, pos + stride); + } + + __syncthreads(); + } +} + +// Merge pairs of lists larger than threads-per-block +template +inline __device__ void BlockSortBigK(T *list_k, S *list_v) { + constexpr int kLoopPerThread = L / NumThreads; + + _Pragma("unroll") for (int loop = 0; loop < kLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + L2CompareAndSwap(list_k, list_v, pos, pos + stride); + } + + __syncthreads(); + + constexpr int kSecondLoopPerThread = FullMerge ? 
kLoopPerThread : kLoopPerThread / 2; + + _Pragma("unroll") for (int stride = L / 2; stride > 0; stride /= 2) { + _Pragma("unroll") for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = 2 * tid - (tid & (stride - 1)); + L2CompareAndSwap(list_k, list_v, pos, pos + stride); + } + __syncthreads(); + } +} + +/// Merging lists smaller than threads-per-block +template +inline __device__ void SortBlockStep(T *list_k, S *list_v) { + if (L <= NumThreads) { + int kNumParallelMerges = NumThreads / L; + int kNumIterations = N / kNumParallelMerges; + + if (N < kNumParallelMerges) { + BlockSortSmallK(list_k, list_v); + } else { + _Pragma("unroll") for (int i = 0; i < kNumIterations; ++i) { + int start = i * kNumParallelMerges * 2 * L; + BlockSortSmallK(list_k + start, list_v + start); + } + } + } else { + _Pragma("unroll") for (int i = 0; i < N; ++i) { + int start = i * 2 * L; + BlockSortBigK(list_k + start, list_v + start); + } + } +} + +// Block-wide merge +template +inline __device__ void SortBlockWide(T *shared_K, S *shared_V) { + if (NumWarps == 2) { + SortBlockStep(shared_K, shared_V); + } else if (NumWarps == 4) { + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, + shared_V); + } else if (NumWarps == 8) { + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, + shared_V); + } else if (NumWarps == 16) { + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + } else if (NumWarps == 32) { + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + SortBlockStep(shared_K, shared_V); + } +} + +template +inline __device__ void BitonicSortWarpLE16(T *k, S *v) { + int laneId = GetLaneId(); + + if (!IsBitonic) { + // Reverse the first comparison stage. head-tail swap. + T other_K = shfl_xor((*k), 2 * L - 1); + S other_V = shfl_xor((*v), 2 * L - 1); + + bool small = !(laneId & L); + bool small_compare = small ? CmpKV::gt((*k), (*v), other_K, other_V) : + CmpKV::lt((*k), (*v), other_K, other_V); + bool small_compare_descend = is_descend ? small_compare : !small_compare; + ConditionalAssign(small_compare_descend, k, other_K); + ConditionalAssign(small_compare_descend, v, other_V); + } + + _Pragma("unroll") for (int stride = IsBitonic ? L : L / 2; stride > 0; stride /= 2) { + T other_K = shfl_xor((*k), stride); + S other_V = shfl_xor((*v), stride); + + bool small = !(laneId & stride); + bool small_compare = small ? CmpKV::gt((*k), (*v), other_K, other_V) : + CmpKV::lt((*k), (*v), other_K, other_V); + bool small_compare_descend = is_descend ? 
small_compare : !small_compare; + ConditionalAssign(small_compare_descend, k, other_K); + ConditionalAssign(small_compare_descend, v, other_V); + } +} + +template +struct MergeWarpStepBitonic {}; + +// All merges call this +template +struct MergeWarpStepBitonic { + static inline __device__ void merge(T k[1], S v[1]) { BitonicSortWarpLE16(&k[0], &v[0]); } +}; + +template +struct MergeWarpStepBitonic { + static inline __device__ void merge(T k[N], S v[N]) { + _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { L2CompareAndSwap(k, v, i, i + N / 2); } + + { + T newK[N / 2]; + S newV[N / 2]; + + _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + T newK[N / 2]; + S newV[N / 2]; + + _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { + newK[i] = k[i + N / 2]; + newV[i] = v[i + N / 2]; + } + + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < N / 2; ++i) { + k[i + N / 2] = newK[i]; + v[i + N / 2] = newV[i]; + } + } + } +}; + +// Low recursion +template +struct MergeWarpStepBitonic { + static inline __device__ void merge(T k[N], S v[N]) { + constexpr int kNextHighestPowerOf2 = NextPow2(N); + + _Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { + L2CompareAndSwap(k, v, i, i + kNextHighestPowerOf2 / 2); + } + + constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; + constexpr int kHighSize = kNextHighestPowerOf2 / 2; + { + T newK[kLowSize]; + S newV[kLowSize]; + + _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + constexpr bool kLowIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2); + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + T newK[kHighSize]; + S newV[kHighSize]; + + _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; + } + + constexpr bool kHighIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2); + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; + } + } + } +}; + +// High recursion +template +struct MergeWarpStepBitonic { + static inline __device__ void merge(T k[N], S v[N]) { + constexpr int kNextHighestPowerOf2 = NextPow2(N); + + _Pragma("unroll") for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { + L2CompareAndSwap(k, v, i, i + kNextHighestPowerOf2 / 2); + } + + constexpr int kLowSize = kNextHighestPowerOf2 / 2; + constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; + { + T newK[kLowSize]; + S newV[kLowSize]; + + _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { + newK[i] = k[i]; + newV[i] = v[i]; + } + + constexpr bool kLowIsPowerOf2 = IsPow2(kNextHighestPowerOf2 / 2); + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < kLowSize; ++i) { + k[i] = newK[i]; + v[i] = newV[i]; + } + } + + { + T newK[kHighSize]; + S newV[kHighSize]; + + _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; + } + + constexpr bool kHighIsPowerOf2 = IsPow2(N - kNextHighestPowerOf2 / 2); + MergeWarpStepBitonic::merge(newK, newV); + + _Pragma("unroll") for (int i = 0; i < kHighSize; ++i) { + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; + } 
+ } + } +}; + +/// Merges two sets of registers across the warp of any size; +template +inline __device__ void MergeWarpByRegister(T k1[N1], S v1[N1], T k2[N2], S v2[N2]) { + constexpr int kSmallestN = N1 < N2 ? N1 : N2; + + _Pragma("unroll") for (int i = 0; i < kSmallestN; ++i) { + T &ka = k1[N1 - 1 - i]; + S &va = v1[N1 - 1 - i]; + + T &kb = k2[i]; + S &vb = v2[i]; + + T other_Ka; + S other_Va; + + if (FullMerge) { + other_Ka = shfl_xor(ka, kWarpSize - 1); + other_Va = shfl_xor(va, kWarpSize - 1); + } + + T other_Kb = shfl_xor(kb, kWarpSize - 1); + S other_Vb = shfl_xor(vb, kWarpSize - 1); + + bool swapa = is_descend ? CmpKV::gt(ka, va, other_Kb, other_Vb) : CmpKV::lt(ka, va, other_Kb, other_Vb); + ConditionalAssign(swapa, &ka, other_Kb); + ConditionalAssign(swapa, &va, other_Vb); + + if (FullMerge) { + bool swapb = is_descend ? CmpKV::lt(kb, vb, other_Ka, other_Va) : + CmpKV::gt(kb, vb, other_Ka, other_Va); + ConditionalAssign(swapb, &kb, other_Ka); + ConditionalAssign(swapb, &vb, other_Va); + } + } + + MergeWarpStepBitonic::merge(k1, v1); + if (FullMerge) { + MergeWarpStepBitonic::merge(k2, v2); + } +} + +// Recursive template that uses the above bitonic merge +template +struct SortWarpStepBitonic { + static inline __device__ void Sort(T k[N], S v[N]) { + constexpr int kSizeA = N / 2; + constexpr int kSizeB = N - kSizeA; + + T aK[kSizeA]; + S aV[kSizeA]; + + _Pragma("unroll") for (int i = 0; i < kSizeA; ++i) { + aK[i] = k[i]; + aV[i] = v[i]; + } + + // Recursive sort + SortWarpStepBitonic::Sort(aK, aV); + + T bK[kSizeB]; + S bV[kSizeB]; + + _Pragma("unroll") for (int i = 0; i < kSizeB; ++i) { + bK[i] = k[i + kSizeA]; + bV[i] = v[i + kSizeA]; + } + + SortWarpStepBitonic::Sort(bK, bV); + + // Merge halves + MergeWarpByRegister(aK, aV, bK, bV); + + _Pragma("unroll") for (int i = 0; i < kSizeA; ++i) { + k[i] = aK[i]; + v[i] = aV[i]; + } + + _Pragma("unroll") for (int i = 0; i < kSizeB; ++i) { + k[i + kSizeA] = bK[i]; + v[i + kSizeA] = bV[i]; + } + } +}; + +template +struct SortWarpStepBitonic { + static inline __device__ void Sort(T k[1], S v[1]) { + // up to warp-size/2 + BitonicSortWarpLE16(&k[0], &v[0]); + BitonicSortWarpLE16(&k[0], &v[0]); + BitonicSortWarpLE16(&k[0], &v[0]); + BitonicSortWarpLE16(&k[0], &v[0]); + BitonicSortWarpLE16(&k[0], &v[0]); + } +}; + +template +inline __device__ void SortWarpByRegister(T k[N], S v[N]) { + SortWarpStepBitonic::Sort(k, v); +} + +template +inline __device__ void MergeWarpQueue(T *threadK, S *threadV, T *warp_K, S *warp_V) { + int laneId = GetLaneId(); + SortWarpByRegister(threadK, threadV); + + constexpr int kWarpQueueRegisters = warp_queue / kWarpSize; + T warp_KRegisters[kWarpQueueRegisters]; + S warp_VRegisters[kWarpQueueRegisters]; + _Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) { + warp_KRegisters[i] = warp_K[i * kWarpSize + laneId]; + warp_VRegisters[i] = warp_V[i * kWarpSize + laneId]; + } + __syncwarp(); + MergeWarpByRegister(warp_KRegisters, warp_VRegisters, + threadK, threadV); + _Pragma("unroll") for (int i = 0; i < kWarpQueueRegisters; ++i) { + warp_K[i * kWarpSize + laneId] = warp_KRegisters[i]; + warp_V[i * kWarpSize + laneId] = warp_VRegisters[i]; + } + __syncwarp(); +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h index afbade332f..178273beed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h +++ 
b/mindspore/ccsrc/backend/kernel_compiler/gpu/random/random_choice_with_mask_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,17 +39,22 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { T *input = GetDeviceAddress(inputs, 0); S *output_index = GetDeviceAddress(outputs, 0); T *output_mask = GetDeviceAddress(outputs, 1); - S *index_buff = GetDeviceAddress(workspaces, 0); - S *mask_buff = GetDeviceAddress(workspaces, 1); - S *rank_buff = GetDeviceAddress(workspaces, 2); - S *Tnum_buff = GetDeviceAddress(workspaces, 3); - S *tmp_buff = GetDeviceAddress(workspaces, 4); - void *States = GetDeviceAddress(workspaces, 5); - curandState *devStates = reinterpret_cast(States); - CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], input_shape_5D_[2], - input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, output_index, output_mask, - index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, devStates, - reinterpret_cast(stream_ptr)); + if (count_ > kSmallK || input_shape_size_ > 1) { + S *index_buff = GetDeviceAddress(workspaces, 0); + S *mask_buff = GetDeviceAddress(workspaces, 1); + S *rank_buff = GetDeviceAddress(workspaces, 2); + S *Tnum_buff = GetDeviceAddress(workspaces, 3); + S *tmp_buff = GetDeviceAddress(workspaces, 4); + void *States = GetDeviceAddress(workspaces, 5); + curandState *devStates = reinterpret_cast(States); + CalRandomChoiceWithMask(input_size_, input_shape_size_, input_shape_5D_[0], input_shape_5D_[1], + input_shape_5D_[2], input_shape_5D_[3], input_shape_5D_[4], seedc_, count_, input, + output_index, output_mask, index_buff, mask_buff, rank_buff, Tnum_buff, tmp_buff, + devStates, reinterpret_cast(stream_ptr)); + } else { + CalRandomChoiceWithMaskSmall(input_size_, seedc_, count_, input, output_index, output_mask, + reinterpret_cast(stream_ptr)); + } return true; } @@ -94,7 +99,9 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { } count_ = static_cast(GetAttr(kernel_node, "count")); // upper ceiling for input for ceil_power2 - ceil_power2_ = RcwmRoundUpPower2(input_size_); + if (count_ > kSmallK || input_shape_size_ > 1) { + ceil_power2_ = RcwmRoundUpPower2(input_size_); + } InitSizeLists(); return true; } @@ -104,16 +111,19 @@ class RandomChoiceWithMaskGpuKernel : public GpuKernel { input_size_list_.push_back(input_size_ * sizeof(T)); output_size_list_.push_back(count_ * input_shape_size_ * sizeof(S)); output_size_list_.push_back(count_ * sizeof(T)); - workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); - workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); - workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); - int blocknum = std::ceil(static_cast(ceil_power2_) / BLOCKSIZE); - workspace_size_list_.push_back(blocknum * sizeof(S)); - workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); - workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); + if (count_ > kSmallK || input_shape_size_ > 1) { + workspace_size_list_.push_back(input_size_ * input_shape_size_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + int blocknum = std::ceil(static_cast(ceil_power2_) / BLOCKSIZE); + workspace_size_list_.push_back(blocknum * sizeof(S)); + 
workspace_size_list_.push_back(ceil_power2_ * sizeof(S)); + workspace_size_list_.push_back(ceil_power2_ * sizeof(curandState)); + } } private: + const int kSmallK = 2048; int input_shape_size_; int seedc_; int input_size_; diff --git a/tests/st/ops/gpu/test_random_choice_with_mask.py b/tests/st/ops/gpu/test_random_choice_with_mask.py index 07bbd26e4e..5944e15c3c 100644 --- a/tests/st/ops/gpu/test_random_choice_with_mask.py +++ b/tests/st/ops/gpu/test_random_choice_with_mask.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P + class RCWM_count_in(nn.Cell): def __init__(self): super(RCWM_count_in, self).__init__() @@ -29,6 +30,7 @@ class RCWM_count_in(nn.Cell): def construct(self, x): return self.RCWM_count_in(x) + class RCWM_count_out(nn.Cell): def __init__(self): super(RCWM_count_out, self).__init__() @@ -37,6 +39,7 @@ class RCWM_count_out(nn.Cell): def construct(self, x): return self.RCWM_count_out(x) + class RCWM_3D(nn.Cell): def __init__(self): super(RCWM_3D, self).__init__() @@ -45,6 +48,16 @@ class RCWM_3D(nn.Cell): def construct(self, x): return self.RCWM_3D(x) + +class RCWM_1D(nn.Cell): + def __init__(self): + super(RCWM_1D, self).__init__() + self.RCWM_1D = P.RandomChoiceWithMask(count=10, seed=9) + + def construct(self, x): + return self.RCWM_1D(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -58,12 +71,14 @@ def test_RCWM_3D(): assert output1.shape == expect1 assert output2.shape == expect2 + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_RCWM_count_out(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], + [0, 0, 0, 1]]).astype(np.bool)) expect1 = (10, 2) expect2 = (10,) rcwm = RCWM_count_out() @@ -71,15 +86,36 @@ def test_RCWM_count_out(): assert output1.shape == expect1 assert output2.shape == expect2 + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_RCWM_count_in(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + input_tensor = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], + [0, 0, 0, 1]]).astype(np.bool)) expect1 = (4, 2) expect2 = (4,) rcwm = RCWM_count_in() output1, output2 = rcwm(input_tensor) assert output1.shape == expect1 assert output2.shape == expect2 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_RCWM_1D(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + input_tensor = Tensor( + np.array([1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1]).astype(np.bool)) + expect_index = np.array([[11], [9], [2], [15], [10], [7], + [8], [0], [0], [0]]).astype(np.int32) + expect_mask = np.array( + [True, True, True, True, True, True, True, True, False, False]) + rcwm = RCWM_1D() + output1, output2 = rcwm(input_tensor) + print(output1.asnumpy()) + print(output2) + assert np.array_equal(output1.asnumpy(), expect_index) + assert 
np.array_equal(output2.asnumpy(), expect_mask) diff --git a/tests/st/ops/gpu/test_topk_op.py b/tests/st/ops/gpu/test_topk_op.py index 83cd8e6403..e938bd73b8 100644 --- a/tests/st/ops/gpu/test_topk_op.py +++ b/tests/st/ops/gpu/test_topk_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-21 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from mindspore.ops import operations as P @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_topk(): +def test_topk_small_2d(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") x_np = np.random.rand(3, 4).astype(np.float32) @@ -36,7 +36,20 @@ def test_topk(): x_np = np.random.rand(3, 4).astype(np.float32) k = 4 ms_output = P.TopK(False)(Tensor(x_np), k) - assert np.allclose(ms_output[0].asnumpy(), x_np) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_topk_3d(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x_np = np.random.rand(2, 256, 128).astype(np.float32) + k = 4 + ms_output = P.TopK(True)(Tensor(x_np), k) + np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] + assert np.allclose(ms_output[0].asnumpy(), np_output) x_np = np.random.rand(2, 3, 4).astype(np.float32) k = 2 @@ -44,6 +57,12 @@ def test_topk(): np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] assert np.allclose(ms_output[0].asnumpy(), np_output) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_topk_big_2d(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") x_np = np.random.rand(512, 1024).astype(np.float32) k = 512 ms_output = P.TopK(True)(Tensor(x_np), k) @@ -51,32 +70,69 @@ def test_topk(): assert np.allclose(ms_output[0].asnumpy(), np_output) # sorted elements num greater than max thread per block - x_np = np.random.rand(512, 2048).astype(np.float32) + x_np = np.random.rand(128, 2048).astype(np.float32) k = 1 ms_output = P.TopK(True)(Tensor(x_np), k) np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] assert np.allclose(ms_output[0].asnumpy(), np_output) - x_np = np.random.rand(512, 2048).astype(np.float32) + x_np = np.random.rand(32, 2048).astype(np.float32) k = 2048 ms_output = P.TopK(True)(Tensor(x_np), k) np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] assert np.allclose(ms_output[0].asnumpy(), np_output) # sorted elements num greater than max share memory per block - x_np = np.random.rand(512, 40960).astype(np.float32) + x_np = np.random.rand(16, 40960).astype(np.float32) k = 1 ms_output = P.TopK(True)(Tensor(x_np), k) np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] assert np.allclose(ms_output[0].asnumpy(), np_output) - x_np = np.random.rand(512, 40960).astype(np.float32) - k = 40960 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_topk_big_k(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + x_np = np.random.rand(8, 40960).astype(np.float32) + k = 4096 ms_output = P.TopK(True)(Tensor(x_np), k) np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k] assert np.allclose(ms_output[0].asnumpy(), np_output) - x_np = np.random.rand(512, 40960).astype(np.float32) - k = 40960 - ms_output = 
P.TopK(False)(Tensor(x_np), k)
-    assert np.allclose(ms_output[0].asnumpy(), x_np)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_topk_1d():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    x_np = np.random.rand(12).astype(np.float32)
+    k = 4
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np)[::-1][0:k]
+
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+    x_np = np.random.rand(1200).astype(np.float32)
+    k = 256
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np)[::-1][0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(250000).astype(np.float32)
+    k = 2000
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np)[::-1][0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(10240).astype(np.float32)
+    k = 4096
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np)[::-1][0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(720).astype(np.float32)
+    k = 720
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np)[::-1][0:k]
+    assert np.allclose(ms_output[0].asnumpy()[:k], np_output)
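For reference, a minimal NumPy sketch (illustrative only, not part of the patch) of the top-k semantics the new FastTopK path and the tests above rely on: the k largest values along the last axis in descending order, with ties resolved toward the lower index as in CmpKV. The helper name topk_reference is an assumption for this sketch, not a name from the codebase.

import numpy as np

def topk_reference(x, k):
    # Stable argsort of the negated input sorts descending while keeping the
    # lower index first among equal values, mirroring the CmpKV tie-break.
    idx = np.argsort(-x, axis=-1, kind="stable")[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx.astype(np.int32)

x_np = np.random.rand(2, 256, 128).astype(np.float32)
values, indices = topk_reference(x_np, 4)
assert np.allclose(values, np.sort(x_np, axis=-1)[..., ::-1][..., 0:4])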