From c10e07734cd6d92038402310298e19ad3d88f5e5 Mon Sep 17 00:00:00 2001
From: wilfChen
Date: Wed, 15 Jul 2020 19:51:30 +0800
Subject: [PATCH] gpu support TopK kernel

---
 mindspore/ccsrc/CMakeLists.txt                  |   2 +-
 .../gpu/arrays/topk_gpu_kernel.cc               |  29 ++++
 .../gpu/arrays/topk_gpu_kernel.h                | 110 ++++++++++++
 .../gpu/cuda_impl/topk_impl.cu                  | 162 ++++++++++++++++++
 .../gpu/cuda_impl/topk_impl.cuh                 |  32 ++++
 .../ccsrc/runtime/device/gpu/cuda_common.h      |   4 +
 tests/st/ops/gpu/test_topk_op.py                |  82 +++++++++
 7 files changed, 420 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
 create mode 100644 tests/st/ops/gpu/test_topk_op.py

diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 53300acda4..472783c501 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -44,7 +44,7 @@ if(ENABLE_GPU)
         "backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
         )
 
-    list(APPEND CUDA_NVCC_FLAGS -arch=sm_53)
+    list(APPEND CUDA_NVCC_FLAGS -arch=sm_53 --expt-relaxed-constexpr)
     list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/blocking_queue.cc" "runtime/device/gpu/gpu_buffer_mgr.cc")
     list(REMOVE_ITEM GPU_SRC_LIST "runtime/device/gpu/mpi/mpi_initializer.cc"
                                   "runtime/device/gpu/distribution/collective_wrapper.cc"
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
new file mode 100644
index 0000000000..59503128e9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(TopK,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeInt32),
+                      TopKGpuKernel, float, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
new file mode 100644
index 0000000000..8b16552c5a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
@@ -0,0 +1,110 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class TopKGpuKernel : public GpuKernel {
+ public:
+  TopKGpuKernel() : sorted_(false), outer_size_(1), inner_size_(1), k_(1), use_share_mem_(true), ceil_power2_(0) {}
+  ~TopKGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    S *k = GetDeviceAddress<S>(inputs, 1);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    S *indices = GetDeviceAddress<S>(outputs, 1);
+    T *data_buff = nullptr;
+    S *index_buff = nullptr;
+    if (use_share_mem_ == false) {
+      data_buff = GetDeviceAddress<T>(workspaces, 0);
+      index_buff = GetDeviceAddress<S>(workspaces, 1);
+    }
+
+    TopK(outer_size_, inner_size_, input_addr, k, output_addr, indices, data_buff, index_buff,
+         reinterpret_cast<cudaStream_t>(stream_ptr));
+
+    // TopK emits the selected values in descending order; when sorted=false,
+    // restore the original element order by re-sorting the k results by index.
+    if (sorted_ == false) {
+      BitonicSortByKey(outer_size_, k_, output_addr, indices, data_buff, index_buff,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < input_shapes.size() - 1; i++) {
+      outer_size_ *= input_shapes[i];
+    }
+    inner_size_ = input_shapes[input_shapes.size() - 1];
+    k_ = output_shapes[output_shapes.size() - 1];
+
+    sorted_ = GetAttr<bool>(kernel_node, "sorted");
+
+    ceil_power2_ = RoundUpPower2(inner_size_);
+    size_t buffer_size = ceil_power2_ * (sizeof(T) + sizeof(S));
+    if (buffer_size > SHARED_MEM_PER_BLOCK) {
+      use_share_mem_ = false;
+      MS_LOG(WARNING) << "CUDA shared memory is not large enough; sorting in global memory instead.";
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(outer_size_ * inner_size_ * sizeof(T));
+    input_size_list_.push_back(sizeof(S));
+    output_size_list_.push_back(outer_size_ * k_ * sizeof(T));
+    output_size_list_.push_back(outer_size_ * k_ * sizeof(S));
+    if (use_share_mem_ == false) {
+      workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(T));
+      workspace_size_list_.push_back(outer_size_ * ceil_power2_ * sizeof(S));
+    }
+  }
+
+ private:
+  bool sorted_;
+  int outer_size_;
+  int inner_size_;
+  int k_;
+  bool use_share_mem_;
+  int ceil_power2_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
new file mode 100644
index 0000000000..6e5ac52903
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
@@ -0,0 +1,162 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
+#include <algorithm>
+#include <limits>
+
+// Round v up to the next power of two by smearing the highest set bit rightward.
+int RoundUpPower2(int v) {
+  v--;
+  v |= v >> 1;
+  v |= v >> 2;
+  v |= v >> 4;
+  v |= v >> 8;
+  v |= v >> 16;
+  v++;
+  return v;
+}
+
+template <typename T>
+__inline__ __device__ void Swap(T *lhs, T *rhs) {
+  T tmp = lhs[0];
+  lhs[0] = rhs[0];
+  rhs[0] = tmp;
+}
+
+template <typename T, typename S>
+__global__ void TopkKernel(const int outer, const int inner, const int ceil_power2, const T *input, const S *k,
+                           T *output, S *indices, T *data_buff, S *index_buff) {
+  // default: sort in shared memory
+  extern __shared__ T share_mem[];
+  T *data_arr = share_mem;
+  S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
+  // fall back to global memory buffers when shared memory is too small
+  if (data_buff != nullptr && index_buff != nullptr) {
+    data_arr = data_buff + blockIdx.x * ceil_power2;
+    index_arr = index_buff + blockIdx.x * ceil_power2;
+  }
+
+  // pad each row to a power of two with maximal sentinel values so the
+  // bitonic network is well formed
+  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
+    data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
+    index_arr[i] = i;
+  }
+  __syncthreads();
+
+  // bitonic sort in ascending order; the sentinels sink to the tail
+  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
+    for (size_t j = (i >> 1); j > 0; j >>= 1) {
+      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
+        size_t tid_comp = tid ^ j;
+        if (tid_comp > tid) {
+          if ((tid & i) == 0) {
+            if (data_arr[tid] > data_arr[tid_comp]) {
+              Swap(&data_arr[tid], &data_arr[tid_comp]);
+              Swap(&index_arr[tid], &index_arr[tid_comp]);
+            }
+          } else {
+            if (data_arr[tid] < data_arr[tid_comp]) {
+              Swap(&data_arr[tid], &data_arr[tid_comp]);
+              Swap(&index_arr[tid], &index_arr[tid_comp]);
+            }
+          }
+        }
+      }
+      __syncthreads();
+    }
+  }
+
+  // the k largest real values sit at indices inner-1 .. inner-k; emit them in descending order
+  for (size_t tid = threadIdx.x; tid < k[0]; tid += blockDim.x) {
+    output[blockIdx.x * k[0] + tid] = data_arr[inner - tid - 1];
+    indices[blockIdx.x * k[0] + tid] = index_arr[inner - tid - 1];
+  }
+}
+
+template <typename T, typename S>
+void TopK(const int &outer, const int &inner, const T *input, const S *k, T *output, S *indices, T *data_buff,
+          S *index_buff, cudaStream_t stream) {
+  int ceil_power2 = RoundUpPower2(inner);
+  int share_mem = (data_buff == nullptr) ? ceil_power2 * (sizeof(T) + sizeof(S)) : 0;
+  int thread = std::min(ceil_power2, GET_THREADS);
+  TopkKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, k, output, indices, data_buff,
+                                                   index_buff);
+}
+
+template <typename T, typename S>
+__global__ void BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
+                                       S *indices, T *data_buff, S *index_buff) {
+  // default: sort in shared memory
+  extern __shared__ T share_mem[];
+  T *data_arr = share_mem;
+  S *index_arr = reinterpret_cast<S *>(data_arr + ceil_power2);
+  // fall back to global memory buffers when shared memory is too small
+  if (data_buff != nullptr && index_buff != nullptr) {
+    data_arr = data_buff + blockIdx.x * ceil_power2;
+    index_arr = index_buff + blockIdx.x * ceil_power2;
+  }
+
+  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
+    data_arr[i] = (i < inner) ? input[blockIdx.x * inner + i] : std::numeric_limits<T>::max();
+    index_arr[i] = (i < inner) ? indices[blockIdx.x * inner + i] : std::numeric_limits<S>::max();
+  }
+  __syncthreads();
+
+  // bitonic sort keyed on the index array, so values return to their original order
+  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
+    for (size_t j = (i >> 1); j > 0; j >>= 1) {
+      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
+        size_t tid_comp = tid ^ j;
+        if (tid_comp > tid) {
+          if ((tid & i) == 0) {
+            if (index_arr[tid] > index_arr[tid_comp]) {
+              Swap(&data_arr[tid], &data_arr[tid_comp]);
+              Swap(&index_arr[tid], &index_arr[tid_comp]);
+            }
+          } else {
+            if (index_arr[tid] < index_arr[tid_comp]) {
+              Swap(&data_arr[tid], &data_arr[tid_comp]);
+              Swap(&index_arr[tid], &index_arr[tid_comp]);
+            }
+          }
+        }
+      }
+      __syncthreads();
+    }
+  }
+
+  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
+    input[blockIdx.x * inner + tid] = data_arr[tid];
+    indices[blockIdx.x * inner + tid] = index_arr[tid];
+  }
+}
+
+template <typename T, typename S>
+void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
+                      cudaStream_t stream) {
+  int ceil_power2 = RoundUpPower2(inner);
+  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
+  if (share_mem > SHARED_MEM_PER_BLOCK) {
+    share_mem = 0;
+  } else {
+    data_buff = nullptr;
+    index_buff = nullptr;
+  }
+  int thread = std::min(ceil_power2, GET_THREADS);
+  BitonicSortByKeyKernel<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, input, indices, data_buff,
+                                                               index_buff);
+}
+
+template void TopK(const int &outer, const int &inner, const float *input_addr, const int *k, float *output,
+                   int *indices, float *data_buff, int *index_buff, cudaStream_t stream);
+template void BitonicSortByKey(const int &outer, const int &inner, float *input, int *indices, float *data_buff,
+                               int *index_buff, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
new file mode 100644
index 0000000000..014044296a
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+void TopK(const int &outer, const int &inner, const T *input_addr, const S *k, T *output, S *indices, T *data_buff,
+          S *index_buff, cudaStream_t stream);
+
+template <typename T, typename S>
+void BitonicSortByKey(const int &outer, const int &inner, T *input, S *indices, T *data_buff, S *index_buff,
+                      cudaStream_t stream);
+
+int RoundUpPower2(int v);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h
index 2689fdbaca..ffe237ab6b 100644
--- a/mindspore/ccsrc/runtime/device/gpu/cuda_common.h
+++ b/mindspore/ccsrc/runtime/device/gpu/cuda_common.h
@@ -30,6 +30,7 @@ class CudaCommon {
   inline int blocks_num(const int total_threads) const {
     return std::min(((total_threads - 1) / threads_per_block_) + 1, max_blocks_);
   }
+  size_t share_memory_size() const { return max_share_memory_; }
 
   static CudaCommon &GetInstance() {
     static CudaCommon instance;
@@ -44,6 +45,7 @@ class CudaCommon {
     threads_per_block_ = prop.maxThreadsPerBlock;
     max_blocks_ = prop.multiProcessorCount;
     major_sm_ = prop.major;
+    max_share_memory_ = prop.sharedMemPerBlock;
   }
   ~CudaCommon() = default;
   CudaCommon(const CudaCommon &) = delete;
@@ -52,10 +54,12 @@ class CudaCommon {
   int max_blocks_;
   int threads_per_block_;
   int major_sm_;
+  size_t max_share_memory_;
 };
 #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
 #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
 #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
+#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()
 #define MINIUM_SM 6
 #define RECOMMEND_SM 7
 } // namespace gpu
diff --git a/tests/st/ops/gpu/test_topk_op.py b/tests/st/ops/gpu/test_topk_op.py
new file mode 100644
index 0000000000..83cd8e6403
--- /dev/null
+++ b/tests/st/ops/gpu/test_topk_op.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_topk():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+    x_np = np.random.rand(3, 4).astype(np.float32)
+    k = 4
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    # with sorted=False and k equal to the row length, values keep their original order
+    x_np = np.random.rand(3, 4).astype(np.float32)
+    k = 4
+    ms_output = P.TopK(False)(Tensor(x_np), k)
+    assert np.allclose(ms_output[0].asnumpy(), x_np)
+
+    x_np = np.random.rand(2, 3, 4).astype(np.float32)
+    k = 2
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 1024).astype(np.float32)
+    k = 512
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    # sorted element count exceeds the max threads per block
+    x_np = np.random.rand(512, 2048).astype(np.float32)
+    k = 1
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 2048).astype(np.float32)
+    k = 2048
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    # sorted element count exceeds the shared memory per block
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 1
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 40960
+    ms_output = P.TopK(True)(Tensor(x_np), k)
+    np_output = np.sort(x_np, axis=-1)[..., ::-1][..., 0:k]
+    assert np.allclose(ms_output[0].asnumpy(), np_output)
+
+    x_np = np.random.rand(512, 40960).astype(np.float32)
+    k = 40960
+    ms_output = P.TopK(False)(Tensor(x_np), k)
+    assert np.allclose(ms_output[0].asnumpy(), x_np)
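
Note: both kernels pad each row to RoundUpPower2(inner) before running the
bitonic network, since bitonic sort requires a power-of-two input length. A
minimal Python cross-check of that bit-twiddling helper (not part of the
patch, illustration only, assumes v >= 1):

    def round_up_power2(v):
        """Mirror of RoundUpPower2 in topk_impl.cu: smear the highest set bit
        rightward, then add one to reach the next power of two."""
        v -= 1
        v |= v >> 1
        v |= v >> 2
        v |= v >> 4
        v |= v >> 8
        v |= v >> 16
        return v + 1

    assert round_up_power2(1024) == 1024    # already a power of two
    assert round_up_power2(40960) == 65536  # padded row length in the 40960-column tests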
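
For reviewers trying the op by hand, a minimal usage sketch (assumes a GPU
device is available; it mirrors the calls in test_topk_op.py):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.array([[5.0, 1.0, 3.0, 2.0]], np.float32))
    values, indices = P.TopK(True)(x, 2)   # sorted=True: descending order
    # values.asnumpy()  -> [[5., 3.]]
    # indices.asnumpy() -> [[0, 2]]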