diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h index 0b00babc0b..fd1ada1057 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_add_gpu_kernel.h @@ -96,10 +96,10 @@ class ScatterAddKernel : public GpuKernel { } private: - int input_size_; - int inner_size_; - int indices_size_; - int updates_size_; + size_t input_size_; + size_t inner_size_; + size_t indices_size_; + size_t updates_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h index e0ac72561f..b5b673999c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_update_gpu_kernel.h @@ -96,10 +96,10 @@ class ScatterUpdateKernel : public GpuKernel { } private: - int input_size_; - int inner_size_; - int indices_size_; - int updates_size_; + size_t input_size_; + size_t inner_size_; + size_t indices_size_; + size_t updates_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu index 8292ef7b1c..81172d7866 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cu @@ -18,7 +18,7 @@ #include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh" template -__global__ void ScatterAdd(const int inner_size, const int updates_size, const int *indices, const T *updates, +__global__ void ScatterAdd(const size_t inner_size, const size_t updates_size, const int *indices, const T *updates, T *input) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { const size_t index = pos / inner_size; @@ -29,16 +29,16 @@ __global__ void ScatterAdd(const int inner_size, const int updates_size, const i } template -void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, +void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, const T *updates, T *input, cudaStream_t cuda_stream) { - const int updates_size = inner_size * indices_size; + const size_t updates_size = inner_size * indices_size; ScatterAdd<<>>(inner_size, updates_size, indices, updates, input); } -template void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, const float *updates, float *input, cudaStream_t cuda_stream); -template void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, const half *updates, half *input, cudaStream_t cuda_stream); -template void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const int *updates, - int *input, cudaStream_t cuda_stream); +template void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, + const int *updates, int *input, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh index 996b9ba606..1c54816563 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh @@ -20,7 +20,7 @@ #include "runtime/device/gpu/cuda_common.h" template -void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, +void CalScatterAdd(const size_t &inner_size, const size_t &indices_size, const int *indices, const T *updates, T *input, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu index a6e79d57db..b93c4a68b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cu @@ -17,32 +17,32 @@ #include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh" template -__global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates, +__global__ void ScatterUpdate(const size_t inner_size, const size_t updates_size, const int *indices, const T *updates, T *input) { - for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { - const int index = pos / inner_size; - const int offset = pos % inner_size; - const int current_pos = indices[index] * inner_size + offset; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) { + const size_t index = pos / inner_size; + const size_t offset = pos % inner_size; + const size_t current_pos = indices[index] * inner_size + offset; input[current_pos] = updates[pos]; } } template -void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, - cudaStream_t cuda_stream) { - const int updates_size = inner_size * indices_size; +void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const T *updates, + T *input, cudaStream_t cuda_stream) { + const size_t updates_size = inner_size * indices_size; ScatterUpdate<<>>(inner_size, updates_size, indices, updates, input); } -template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const float *updates, float *input, cudaStream_t cuda_stream); -template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const half *updates, half *input, cudaStream_t cuda_stream); -template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const int *updates, int *input, cudaStream_t cuda_stream); -template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, +template void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const unsigned char *updates, unsigned char *input, cudaStream_t cuda_stream); -template void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, - const int8_t *updates, int8_t *input, cudaStream_t cuda_stream); +template void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, + const int8_t *updates, int8_t *input, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh index a10ccc7ae4..94e1b31d47 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh @@ -20,7 +20,7 @@ #include "runtime/device/gpu/cuda_common.h" template -void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input, - cudaStream_t cuda_stream); +void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const T *updates, + T *input, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_UPDATE_IMPL_CUH_