|
|
|
@ -23,9 +23,9 @@ __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t b
|
|
|
|
|
const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
|
|
|
|
|
S *indices_stride, S *work_shape) {
|
|
|
|
|
int i, j;
|
|
|
|
|
for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
|
|
|
|
|
for (size_t read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
|
|
|
|
|
read_index += blockDim.x * gridDim.x) {
|
|
|
|
|
int write_index = 0;
|
|
|
|
|
size_t write_index = 0;
|
|
|
|
|
bool out_bound = false;
|
|
|
|
|
|
|
|
|
|
i = read_index / block_size;
|
|
|
|
@ -51,8 +51,8 @@ void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const
|
|
|
|
|
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
|
|
|
|
S *work_shape, cudaStream_t stream) {
|
|
|
|
|
ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
|
|
|
|
|
output_size, indices_dim_0, indices_dim_1,
|
|
|
|
|
indices_stride, work_shape);
|
|
|
|
|
output_size, indices_dim_0, indices_dim_1,
|
|
|
|
|
indices_stride, work_shape);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -60,21 +60,43 @@ template void ScatterNd<float, int>(int *indices, float *update, float *output,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for float update/output buffers with
// int64_t indices, indices_stride and work_shape.
template void ScatterNd<float, int64_t>(int64_t *indices, float *update, float *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size,
|
|
|
|
|
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
|
|
|
|
int64_t *indices_stride, int64_t *work_shape, cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for half update/output buffers with
// int indices, indices_stride and work_shape.
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for half update/output buffers with
// int64_t indices, indices_stride and work_shape.
template void ScatterNd<half, int64_t>(int64_t *indices, half *update, half *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for int update/output buffers with
// int indices, indices_stride and work_shape.
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for int update/output buffers with
// int64_t indices, indices_stride and work_shape.
template void ScatterNd<int, int64_t>(int64_t *indices, int *update, int *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int64_t *indices_stride, int64_t *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for short update/output buffers with
// int indices, indices_stride and work_shape.
// NOTE(review): the lint suppression below presumably targets the use of the
// bare `short` type -- confirm against the project's lint configuration.
// NOLINTNEXTLINE
|
|
|
|
|
template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for short update/output buffers with
// int64_t indices, indices_stride and work_shape.
// NOTE(review): the lint suppression below presumably targets the use of the
// bare `short` type -- confirm against the project's lint configuration.
// NOLINTNEXTLINE
|
|
|
|
|
template void ScatterNd<short, int64_t>(int64_t *indices, short *update, short *output, const size_t &block_size,
|
|
|
|
|
const size_t &input_size, const size_t &output_size,
|
|
|
|
|
const size_t &indices_dim_0, const size_t &indices_dim_1,
|
|
|
|
|
int64_t *indices_stride, int64_t *work_shape, cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for unsigned char update/output buffers
// with int indices, indices_stride and work_shape.
template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
|
|
|
|
|
const size_t &block_size, const size_t &input_size,
|
|
|
|
|
const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
|
|
|
|
|
cudaStream_t stream);
|
|
|
|
|
// Explicit instantiation of ScatterNd for unsigned char update/output buffers
// with int64_t indices, indices_stride and work_shape.
template void ScatterNd<unsigned char, int64_t>(int64_t *indices, unsigned char *update, unsigned char *output,
|
|
|
|
|
const size_t &block_size, const size_t &input_size,
|
|
|
|
|
const size_t &output_size, const size_t &indices_dim_0,
|
|
|
|
|
const size_t &indices_dim_1, int64_t *indices_stride,
|
|
|
|
|
int64_t *work_shape, cudaStream_t stream);
|
|
|
|
|