!8537 Remove use_locking for GPU ScatterAdd

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
pull/8537/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 357a28d74d

@@ -43,8 +43,7 @@ class ScatterAddKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
CalScatterAdd(inner_size_, indices_size_, use_locking_, indices, updates, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalScatterAdd(inner_size_, indices_size_, indices, updates, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -72,7 +71,6 @@ class ScatterAddKernel : public GpuKernel {
indices_size_ *= indices_shape[i];
}
updates_size_ = indices_size_ * inner_size_;
use_locking_ = GetAttr<bool>(kernel_node, "use_locking");
InitSizeLists();
return true;
}
@@ -82,7 +80,6 @@ class ScatterAddKernel : public GpuKernel {
inner_size_ = 0;
indices_size_ = 0;
updates_size_ = 0;
use_locking_ = true;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
@@ -101,7 +98,6 @@ class ScatterAddKernel : public GpuKernel {
int inner_size_;
int indices_size_;
int updates_size_;
bool use_locking_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

@@ -18,31 +18,27 @@
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
// ScatterAdd kernel: for every element of `updates`, adds it into the row of `output` selected by `indices`.
//
// NOTE: this is a flattened diff view. Removed (pre-commit) and added (post-commit) lines appear back to back
// without +/- markers; the lines that mention `use_locking` are the ones this commit removes.
//
//   inner_size   - divisor used to split a flat position into (index, offset)
//   updates_size - loop bound; the host launcher passes inner_size * indices_size
//   indices      - read as indices[pos / inner_size] to pick the destination row
//   updates      - values to add, read at the flat position `pos`
//   output       - accumulated in place at indices[index] * inner_size + offset
template <typename T>
// Pre-commit signature (carries the use_locking flag):
__global__ void ScatterAdd(const int inner_size, const int updates_size, const bool use_locking, const int *indices,
const T *updates, T *output) {
// Post-commit signature (use_locking removed):
__global__ void ScatterAdd(const int inner_size, const int updates_size, const int *indices, const T *updates,
T *output) {
// Each thread starts at its global thread index and advances by the total number of launched threads
// (blockDim.x * gridDim.x) until it passes updates_size.
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
// Split the flat position into the entry of `indices` to consult and the offset inside that row.
const size_t index = pos / inner_size;
const size_t offset = pos % inner_size;
// NOTE(review): `indices[index] * inner_size` multiplies two int operands before the result is widened to
// size_t, and no range check on indices[index] is visible here - confirm the caller guarantees valid,
// non-negative indices and that the product cannot overflow int.
const size_t current_pos = indices[index] * inner_size + offset;
// Pre-commit body: MsAtomicAdd only when use_locking is set, otherwise a plain read-modify-write.
// NOTE(review): presumably the plain `+=` could lose updates when several positions map to the same
// output element, which would explain the removal - confirm against the commit discussion.
if (use_locking) {
MsAtomicAdd(&output[current_pos], updates[pos]);
} else {
output[current_pos] += updates[pos];
}
// Post-commit body: always accumulate through MsAtomicAdd.
MsAtomicAdd(&output[current_pos], updates[pos]);
}
}
// Host-side launcher for the ScatterAdd kernel.
//
// NOTE: this is a flattened diff view. Removed (pre-commit) and added (post-commit) lines appear back to back
// without +/- markers; the lines that mention `use_locking` are the ones this commit removes.
//
// Computes updates_size = inner_size * indices_size and launches ScatterAdd on `cuda_stream` with
// GET_BLOCKS(updates_size) blocks of GET_THREADS threads and 0 bytes of dynamic shared memory.
// NOTE(review): no launch error check (e.g. cudaGetLastError) is visible after the launch - confirm whether
// the caller performs one.
template <typename T>
// Pre-commit signature (carries the use_locking flag):
void CalScatterAdd(const int &inner_size, const int &indices_size, const bool &use_locking, const int *indices,
const T *updates, T *output, cudaStream_t cuda_stream) {
// Post-commit signature (use_locking removed):
void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
cudaStream_t cuda_stream) {
// Total number of update elements; also the loop bound handed to the kernel.
const int updates_size = inner_size * indices_size;
// Pre-commit launch (forwards use_locking):
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, use_locking,
indices, updates, output);
// Post-commit launch (no use_locking argument):
ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, indices, updates,
output);
}
// Explicit instantiations of CalScatterAdd for float, half and int.
// NOTE: flattened diff view - the first three (with use_locking) are the pre-commit instantiations that this
// commit removes; the three after them are the post-commit replacements.
// Pre-commit instantiations:
template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const bool &use_locking,
const int *indices, const float *updates, float *output, cudaStream_t cuda_stream);
template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const bool &use_locking,
const int *indices, const half *updates, half *output, cudaStream_t cuda_stream);
template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const bool &use_locking,
const int *indices, const int *updates, int *output, cudaStream_t cuda_stream);
// Post-commit instantiations:
template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const int *indices,
const float *updates, float *output, cudaStream_t cuda_stream);
template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const int *indices,
const half *updates, half *output, cudaStream_t cuda_stream);
template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const int *indices, const int *updates,
int *output, cudaStream_t cuda_stream);

@@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
// Declaration of the host-side ScatterAdd launcher, taking a CUDA stream as its last parameter.
// NOTE: flattened diff view - the first declaration (with use_locking) is the pre-commit form that this commit
// removes; the second is the post-commit form.
template <typename T>
// Pre-commit declaration:
void CalScatterAdd(const int &inner_size, const int &indices_size, const bool &use_locking, const int *indices,
const T *updates, T *output, cudaStream_t cuda_stream);
// Post-commit declaration:
void CalScatterAdd(const int &inner_size, const int &indices_size, const int *indices, const T *updates, T *output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_

Loading…
Cancel
Save