From 356844da46ddaf23b4e1367262f0272658cb7379 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Mon, 2 Nov 2020 19:01:33 -0500 Subject: [PATCH] refine GPU-ScatterUpdate --- .../gpu/arrays/scatter_update_gpu_kernel.cc | 7 ++ .../gpu/arrays/scatter_update_gpu_kernel.h | 6 +- .../gpu/cuda_impl/scatter_update_impl.cu | 34 +++++----- .../gpu/cuda_impl/scatter_update_impl.cuh | 4 +- tests/st/ops/gpu/test_scatter_update_op.py | 65 ++++++++++++++++++- 5 files changed, 91 insertions(+), 25 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc index d3737d6700..dd3087635b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.cc @@ -32,5 +32,12 @@ MS_REG_GPU_KERNEL_ONE(ScatterUpdate, .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), ScatterUpdateKernel, half) +MS_REG_GPU_KERNEL_ONE(ScatterUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + ScatterUpdateKernel, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h index 5b2727965f..b88de4ffa7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h @@ -40,8 +40,10 @@ class ScatterUpdateKernel : public GpuKernel { int *indices = GetDeviceAddress(inputs, 1); T *updates = GetDeviceAddress(inputs, 2); T *output = GetDeviceAddress(outputs, 0); - CalScatterUpdate(input_size_, inner_size_, indices_size_, input, indices, updates, output, - reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "cudaMemcpyAsync output failed"); + CalScatterUpdate(inner_size_, indices_size_, indices, updates, output, reinterpret_cast(stream_ptr)); return true; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu index e0d6898e14..2610288337 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu @@ -17,29 +17,27 @@ #include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh" template -__global__ void ScatterUpdate(const int input_size, const int inner_size, const int indices_size, const T *input, - const int *indices, const T *updates, T *output) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) { - output[pos] = input[pos]; +__global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates, + T *output) { + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { const int index = pos / inner_size; const int offset = pos % inner_size; - for (int i = 0; i < indices_size; i++) { - const int update_pos = i * inner_size + offset; - output[pos] = (indices[i] == index ? updates[update_pos] : output[pos]); - } + const int current_pos = indices[index] * inner_size + offset; + output[current_pos] = updates[pos]; } } template -void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input, - const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) { - ScatterUpdate<<>>(input_size, inner_size, indices_size, input, - indices, updates, output); +void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, + cudaStream_t cuda_stream) { + const int updates_size = inner_size * indices_size; + ScatterUpdate<<>>(inner_size, updates_size, indices, updates, + output); } -template void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, - const float *input, const int *indices, const float *updates, float *output, - cudaStream_t cuda_stream); -template void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, - const half *input, const int *indices, const half *updates, half *output, - cudaStream_t cuda_stream); +template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, + const float *updates, float *output, cudaStream_t cuda_stream); +template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, + const half *updates, half *output, cudaStream_t cuda_stream); +template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, + const int *updates, int *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh index f11bcb3972..fc59c02e7e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh @@ -20,7 +20,7 @@ #include "runtime/device/gpu/cuda_common.h" template -void CalScatterUpdate(const int &input_size, const int &inner_size, const int &indices_size, const T *input, - const int *indices, const T *updates, T *output, cudaStream_t cuda_stream); +void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output, + cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_ diff --git a/tests/st/ops/gpu/test_scatter_update_op.py b/tests/st/ops/gpu/test_scatter_update_op.py index 98b235484c..d5169357e3 100644 --- a/tests/st/ops/gpu/test_scatter_update_op.py +++ b/tests/st/ops/gpu/test_scatter_update_op.py @@ -75,7 +75,19 @@ def test_scatter_update_float16(): updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float16)) output = scatter_update_net(inputx, indices, updates) expected = np.array([[0., 1., 2.], - [3., 4., 5.]]) + [3., 4., 5.]]).astype(np.float16) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scatter_update_int32(): + inputx = Tensor(np.zeros((2, 3)).astype(np.int32)) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.int32)) + output = scatter_update_net(inputx, indices, updates) + expected = np.array([[0., 1., 2.], + [3., 4., 5.]]).astype(np.int32) np.testing.assert_array_almost_equal(output.asnumpy(), expected) @pytest.mark.level0 @@ -89,7 +101,7 @@ def test_scatter_update_large_float16(): expected = np.array([[69., 70., 71.], [66., 67., 68.], [63., 64., 65.], - [72., 73., 74.]]) + [72., 73., 74.]]).astype(np.float16) np.testing.assert_array_almost_equal(output.asnumpy(), expected) @pytest.mark.level0 @@ -102,5 +114,52 @@ def test_scatter_update_disordered_float16(): output = scatter_update_net(inputx, indices, updates) expected = np.array([[45., 44., 43., 42.], [63., 64., 65., 66.], - [67., 68., 69., 70.]]) + [67., 68., 69., 70.]]).astype(np.float16) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scatter_update_disordered_int32(): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32))) + indices = Tensor(np.array([1, 2]).astype(np.int32)) + updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int32)) + output = scatter_update_net(inputx, indices, updates) + expected = np.array([[45., 44., 43., 42.], + [63., 64., 65., 66.], + [67., 68., 69., 70.]]).astype(np.int32) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scatter_update_large_shape_float16(): + inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.float16)) + indices = Tensor(np.array([1, 0]).astype(np.int32)) + updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.float16))) + output = scatter_update_net(inputx, indices, updates) + expected = np.array([[[[23., 22., 21., 20.], + [19., 18., 17., 16.], + [15., 14., 13., 12.]], + [[11., 10., 9., 8.], + [7., 6., 5., 4.], + [3., 2., 1., 0.]]], + [[[47., 46., 45., 44.], + [43., 42., 41., 40.], + [39., 38., 37., 36.]], + [[35., 34., 33., 32.], + [31., 30., 29., 28.], + [27., 26., 25., 24.]]], + [[[48., 49., 50., 51.], + [52., 53., 54., 55.], + [56., 57., 58., 59.]], + [[60., 61., 62., 63.], + [64., 65., 66., 67.], + [68., 69., 70., 71.]]], + [[[72., 73., 74., 75.], + [76., 77., 78., 79.], + [80., 81., 82., 83.]], + [[84., 85., 86., 87.], + [88., 89., 90., 91.], + [92., 93., 94., 95.]]]]).astype(np.float16) np.testing.assert_array_almost_equal(output.asnumpy(), expected)