diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc index c3324a56c0..9051d617f0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.cc @@ -39,5 +39,19 @@ MS_REG_GPU_KERNEL_ONE(ScatterAdd, .AddInputAttr(kNumberTypeInt32) .AddOutputAttr(kNumberTypeInt32), ScatterAddKernel, int) +MS_REG_GPU_KERNEL_ONE(ScatterAdd, + KernelAttr() + .AddInputAttr(kNumberTypeInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt8) + .AddOutputAttr(kNumberTypeInt8), + ScatterAddKernel, int8_t) +MS_REG_GPU_KERNEL_ONE(ScatterAdd, + KernelAttr() + .AddInputAttr(kNumberTypeUInt8) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeUInt8) + .AddOutputAttr(kNumberTypeUInt8), + ScatterAddKernel, uint8_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu index 81172d7866..05bcf15610 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu @@ -42,3 +42,8 @@ template void CalScatterAdd(const size_t &inner_size, const size_t &indice const half *updates, half *input, cudaStream_t cuda_stream); template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, const int *updates, int *input, cudaStream_t cuda_stream); +template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, + const unsigned char *updates, unsigned char *input, + cudaStream_t cuda_stream); +template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, + const int8_t *updates, int8_t *input, cudaStream_t cuda_stream); diff --git a/tests/st/ops/gpu/test_scatter_add_op.py b/tests/st/ops/gpu/test_scatter_add_op.py index 493d00505d..5041a633f3 100644 --- a/tests/st/ops/gpu/test_scatter_add_op.py +++ b/tests/st/ops/gpu/test_scatter_add_op.py @@ -269,6 +269,38 @@ def test_scatter_add_disordered_dynamic_int32(): [492., 496., 500., 504.]]).astype(np.int32) np.testing.assert_array_almost_equal(output.asnumpy(), expected) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scatter_add_disordered_dynamic_int8(): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8))) + indices = Tensor(np.array([[[0, 1, 2], + [2, 1, 0]], + [[0, 0, 0], + [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.int8)) + output = scatter_add_d_net(inputx, indices, updates) + expected = np.array([[464., 468., 472., 476.], + [187., 188., 189., 190.], + [492., 496., 500., 504.]]).astype(np.int8) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_scatter_add_disordered_dynamic_uint8(): + inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8))) + indices = Tensor(np.array([[[0, 1, 2], + [2, 1, 0]], + [[0, 0, 0], + [2, 2, 2]]]).astype(np.int32)) + updates = Tensor(np.arange(63, 111).reshape((2, 2, 3, 4)).astype(np.uint8)) + output = scatter_add_d_net(inputx, indices, updates) + expected = np.array([[464., 468., 472., 476.], + [187., 188., 189., 190.], + [492., 496., 500., 504.]]).astype(np.uint8) + np.testing.assert_array_almost_equal(output.asnumpy(), expected) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard