|
|
|
@ -17,32 +17,32 @@
|
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_update_impl.cuh"
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
__global__ void ScatterUpdate(const int inner_size, const int updates_size, const int *indices, const T *updates,
|
|
|
|
|
__global__ void ScatterUpdate(const size_t inner_size, const size_t updates_size, const int *indices, const T *updates,
|
|
|
|
|
T *input) {
|
|
|
|
|
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
|
|
|
|
|
const int index = pos / inner_size;
|
|
|
|
|
const int offset = pos % inner_size;
|
|
|
|
|
const int current_pos = indices[index] * inner_size + offset;
|
|
|
|
|
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
|
|
|
|
|
const size_t index = pos / inner_size;
|
|
|
|
|
const size_t offset = pos % inner_size;
|
|
|
|
|
const size_t current_pos = indices[index] * inner_size + offset;
|
|
|
|
|
input[current_pos] = updates[pos];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void CalScatterUpdate(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *input,
|
|
|
|
|
cudaStream_t cuda_stream) {
|
|
|
|
|
const int updates_size = inner_size * indices_size;
|
|
|
|
|
void CalScatterUpdate(const size_t &inner_size, const size_t &indices_size, const int *indices, const T *updates,
|
|
|
|
|
T *input, cudaStream_t cuda_stream) {
|
|
|
|
|
const size_t updates_size = inner_size * indices_size;
|
|
|
|
|
ScatterUpdate<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates,
|
|
|
|
|
input);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template void CalScatterUpdate<float>(const int &inner_size, const int &indices_size, const int *indices,
|
|
|
|
|
template void CalScatterUpdate<float>(const size_t &inner_size, const size_t &indices_size, const int *indices,
|
|
|
|
|
const float *updates, float *input, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterUpdate<half>(const int &inner_size, const int &indices_size, const int *indices,
|
|
|
|
|
template void CalScatterUpdate<half>(const size_t &inner_size, const size_t &indices_size, const int *indices,
|
|
|
|
|
const half *updates, half *input, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterUpdate<int>(const int &inner_size, const int &indices_size, const int *indices,
|
|
|
|
|
template void CalScatterUpdate<int>(const size_t &inner_size, const size_t &indices_size, const int *indices,
|
|
|
|
|
const int *updates, int *input, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterUpdate<unsigned char>(const int &inner_size, const int &indices_size, const int *indices,
|
|
|
|
|
template void CalScatterUpdate<unsigned char>(const size_t &inner_size, const size_t &indices_size, const int *indices,
|
|
|
|
|
const unsigned char *updates, unsigned char *input,
|
|
|
|
|
cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterUpdate<int8_t>(const int &inner_size, const int &indices_size, const int *indices,
|
|
|
|
|
const int8_t *updates, int8_t *input, cudaStream_t cuda_stream);
|
|
|
|
|
template void CalScatterUpdate<int8_t>(const size_t &inner_size, const size_t &indices_size, const int *indices,
|
|
|
|
|
const int8_t *updates, int8_t *input, cudaStream_t cuda_stream);
|
|
|
|
|