|
|
|
@ -18,10 +18,9 @@
|
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void ScatterAdd(const int input_size, const int inner_size, const int indices_size, const int updates_size,
|
|
|
|
|
const bool use_locking, const T *input, const int *indices, const T *updates, T *output) {
|
|
|
|
|
__global__ void ScatterAdd(const int inner_size, const int updates_size, const bool use_locking, const int *indices,
|
|
|
|
|
const T *updates, T *output) {
|
|
|
|
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
|
|
|
|
|
output[pos] = input[pos];
|
|
|
|
|
const size_t index = pos / inner_size;
|
|
|
|
|
const size_t offset = pos % inner_size;
|
|
|
|
|
const size_t current_pos = indices[index] * inner_size + offset;
|
|
|
|
@ -34,19 +33,16 @@ __global__ void ScatterAdd(const int input_size, const int inner_size, const int
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void CalScatterAdd(const int &input_size, const int &inner_size, const int &indices_size, const bool &use_locking,
|
|
|
|
|
const T *input, const int *indices, const T *updates, T *output, cudaStream_t cuda_stream) {
|
|
|
|
|
void CalScatterAdd(const int &inner_size, const int &indices_size, const bool &use_locking, const int *indices,
|
|
|
|
|
const T *updates, T *output, cudaStream_t cuda_stream) {
|
|
|
|
|
const int updates_size = inner_size * indices_size;
|
|
|
|
|
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(
|
|
|
|
|
input_size, inner_size, indices_size, updates_size, use_locking, input, indices, updates, output);
|
|
|
|
|
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, use_locking,
|
|
|
|
|
indices, updates, output);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template void CalScatterAdd<float>(const int &input_size, const int &inner_size, const int &indices_size,
|
|
|
|
|
const bool &use_locking, const float *input, const int *indices,
|
|
|
|
|
const float *updates, float *output, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterAdd<half>(const int &input_size, const int &inner_size, const int &indices_size,
|
|
|
|
|
const bool &use_locking, const half *input, const int *indices, const half *updates,
|
|
|
|
|
half *output, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterAdd<int>(const int &input_size, const int &inner_size, const int &indices_size,
|
|
|
|
|
const bool &use_locking, const int *input, const int *indices, const int *updates,
|
|
|
|
|
int *output, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const bool &use_locking,
|
|
|
|
|
const int *indices, const float *updates, float *output, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const bool &use_locking,
|
|
|
|
|
const int *indices, const half *updates, half *output, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const bool &use_locking,
|
|
|
|
|
const int *indices, const int *updates, int *output, cudaStream_t cuda_stream);
|
|
|
|
|