From 5b7790a2a79fff1195a0d178f38a612948a20e91 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Wed, 15 Apr 2020 09:11:07 +0800 Subject: [PATCH] Gpu Slice kernel performance improvement --- .../kernel/gpu/arrays/slice_gpu_kernel.h | 5 +- .../ccsrc/kernel/gpu/cuda_impl/slice_impl.cu | 62 ++++++++++--------- .../ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh | 7 ++- tests/st/ops/gpu/test_slice.py | 19 ++++++ 4 files changed, 60 insertions(+), 33 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h index 96e899da60..e70c403cfd 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h @@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel { CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, reinterpret_cast(stream_ptr)); } else { - CalSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, output, - reinterpret_cast(stream_ptr)); + Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0], + input_shape_[1], input_shape_[2], input_shape_[3], input, output, + reinterpret_cast(stream_ptr)); } return true; } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu index 7f5bcdf81f..78a52149ae 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu @@ -21,11 +21,22 @@ #include "kernel/gpu/cuda_impl/slice_impl.cuh" template -__global__ void Slice(const T* input, int p, int start, int length, T* output) { - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { - output[p + pos] = input[start + pos]; +__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) { + int i = pos / (l2 * l3 * l4) % l1; + int j = pos / (l3 * l4) % l2; + int k = pos / l4 % l3; + int o = pos % l4; + + int offset = (i + s1) * (d2 * d3 * d4) + + (j + s2) * (d3 * d4) + + (k + s3) * d4 + + (o + s4); + output[pos] = input[offset]; } - return; } template __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { @@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt return; } template -void CalSlice(const size_t input_size, const T* input, const std::vector in_shape, const std::vector begin, - const std::vector size, T* output, cudaStream_t cuda_stream) { - int block = in_shape[1] * in_shape[2] * in_shape[3]; - int map = in_shape[2] * in_shape[3]; - int w = in_shape[3]; - int length = size[3]; - int p = 0; - for (int i = begin[0]; i < size[0] + begin[0]; i++) { - for (int j = begin[1]; j < size[1] + begin[1]; j++) { - for (int k = begin[2]; k < size[2] + begin[2]; k++) { - Slice<<>>(input, p, i * block + j * map + k * w + begin[3], - length, output); - p = p + size[3]; - } - } - } +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream) { + Slice4D<<>>(s1, s2, s3, s4, l1, l2, l3, l4, + d1, d2, d3, d4, input, output); } template void CalSliceGrad(const size_t input_size, const T* dy, const std::vector in_shape, const std::vector begin, @@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector } template void FillDeviceArray(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); -template void CalSlice(const size_t input_size, const float* input, const std::vector in_shape, - const std::vector begin, const std::vector size, float* output, - cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const float *input, float *output, cudaStream_t stream); template void CalSliceGrad(const size_t input_size, const float* dy, const std::vector in_shape, const std::vector begin, const std::vector size, float* output, cudaStream_t cuda_stream); @@ -160,9 +162,10 @@ template void CalStridedSliceGrad(const size_t input_size, const float* d const std::vector begin, const std::vector end, const std::vector strides, float* dx, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); -template void CalSlice(const size_t input_size, const half* input, const std::vector in_shape, - const std::vector begin, const std::vector size, half* output, - cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const half *input, half *output, cudaStream_t stream); template void CalSliceGrad(const size_t input_size, const half* dy, const std::vector in_shape, const std::vector begin, const std::vector size, half* output, cudaStream_t cuda_stream); @@ -173,9 +176,10 @@ template void CalStridedSliceGrad(const size_t input_size, const half* dy, const std::vector begin, const std::vector end, const std::vector strides, half* dx, cudaStream_t cuda_stream); template void FillDeviceArray(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); -template void CalSlice(const size_t input_size, const int* input, const std::vector in_shape, - const std::vector begin, const std::vector size, int* output, - cudaStream_t cuda_stream); +template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const int *input, int *output, cudaStream_t stream); template void CalSliceGrad(const size_t input_size, const int* dy, const std::vector in_shape, const std::vector begin, const std::vector size, int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh index d88ce29c51..9513d6ed24 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh @@ -21,9 +21,12 @@ #include #include "device/gpu/cuda_common.h" + template -void CalSlice(const size_t input_size, const T* input, const std::vector in_shape, const std::vector begin, - const std::vector size, T* output, cudaStream_t cuda_stream); +void Slice4DKernel(const int s1, const int s2, const int s3, const int s4, + const int l1, const int l2, const int l3, const int l4, + const int d1, const int d2, const int d3, const int d4, + const T *input, T *output, cudaStream_t stream); template void CalSliceGrad(const size_t input_size, const T* input, const std::vector in_shape, const std::vector begin, const std::vector size, T* output, cudaStream_t cuda_stream); diff --git a/tests/st/ops/gpu/test_slice.py b/tests/st/ops/gpu/test_slice.py index 95aed4ffa1..9846399481 100644 --- a/tests/st/ops/gpu/test_slice.py +++ b/tests/st/ops/gpu/test_slice.py @@ -43,3 +43,22 @@ def test_slice(): slice = Slice() output = slice(x) assert (output.asnumpy() == expect).all() + + +class SliceNet(nn.Cell): + def __init__(self): + super(SliceNet, self).__init__() + self.slice = P.Slice() + + def construct(self, x): + return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224)) + +def test_slice_4d(): + x_np = np.random.randn(32, 24, 224, 224).astype(np.float32) + output_np = x_np[:, 11:18, :, :] + + x_ms = Tensor(x_np) + net = SliceNet() + output_ms = net(x_ms) + + assert (output_ms.asnumpy() == output_np).all()