Adapt changes to GPU framework for ScatterAdd/Update and SparseFtrl

pull/8433/head
TFbunny 4 years ago
parent 5caded733e
commit 24ea2ddf92

@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class ScatterAddKernel : public GpuKernel {
public:
// Delegate all member initialization to ResetResource() so construction and
// reuse-after-reset share a single code path.
ScatterAddKernel() { ResetResource(); }
~ScatterAddKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -43,7 +43,7 @@ class ScatterAddKernel : public GpuKernel {
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&output[0], &input[0], input_size_ * sizeof(T), cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
// `output` already holds a copy of the input (cudaMemcpyAsync above); the kernel
// only scatters `updates` into it, so `input_size_`/`input` are no longer passed.
CalScatterAdd(inner_size_, indices_size_, use_locking_, indices, updates, output,
              reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@ -77,6 +77,17 @@ class ScatterAddKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
  // Restore every cached size and flag to its construction-time default so the
  // kernel object can be re-initialized for a new graph.
  input_size_ = inner_size_ = indices_size_ = updates_size_ = 0;
  use_locking_ = true;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));

@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class ScatterUpdateKernel : public GpuKernel {
public:
// Delegate all member initialization to ResetResource() so construction and
// reuse-after-reset share a single code path.
ScatterUpdateKernel() { ResetResource(); }
~ScatterUpdateKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -75,6 +75,16 @@ class ScatterUpdateKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
  // Zero every cached tensor size and drop the recorded size lists so the
  // kernel object can be re-initialized for a new graph.
  input_size_ = inner_size_ = indices_size_ = updates_size_ = 0;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));

@ -18,10 +18,9 @@
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_add_impl.cuh"
template <typename T>
// Grid-stride loop over the flattened `updates` buffer. Element `pos` maps to
// output row indices[pos / inner_size], column pos % inner_size. The caller is
// responsible for pre-filling `output` with the input tensor, so the old
// `input` parameter and the per-element `output[pos] = input[pos]` copy are gone.
__global__ void ScatterAdd(const int inner_size, const int updates_size, const bool use_locking, const int *indices,
                           const T *updates, T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < updates_size; pos += blockDim.x * gridDim.x) {
    const size_t index = pos / inner_size;
    const size_t offset = pos % inner_size;
    const size_t current_pos = indices[index] * inner_size + offset;
@ -34,19 +33,16 @@ __global__ void ScatterAdd(const int input_size, const int inner_size, const int
}
template <typename T>
// Host-side launcher for ScatterAdd. Precondition: `output` already contains a
// copy of the input tensor; the kernel only accumulates `updates` into it.
// One thread is scheduled per update element (indices_size rows x inner_size cols).
void CalScatterAdd(const int &inner_size, const int &indices_size, const bool &use_locking, const int *indices,
                   const T *updates, T *output, cudaStream_t cuda_stream) {
  const int updates_size = inner_size * indices_size;
  ScatterAdd<<<GET_BLOCKS(updates_size), GET_THREADS, 0, cuda_stream>>>(inner_size, updates_size, use_locking,
                                                                        indices, updates, output);
}
// Explicit instantiations for the element types the GPU kernel supports,
// matching the reduced post-refactor signature (no input_size / input pointer).
template void CalScatterAdd<float>(const int &inner_size, const int &indices_size, const bool &use_locking,
                                   const int *indices, const float *updates, float *output, cudaStream_t cuda_stream);
template void CalScatterAdd<half>(const int &inner_size, const int &indices_size, const bool &use_locking,
                                  const int *indices, const half *updates, half *output, cudaStream_t cuda_stream);
template void CalScatterAdd<int>(const int &inner_size, const int &indices_size, const bool &use_locking,
                                 const int *indices, const int *updates, int *output, cudaStream_t cuda_stream);

@ -20,7 +20,7 @@
#include "runtime/device/gpu/cuda_common.h"
// Launches ScatterAdd on `cuda_stream`. Precondition: `output` already holds a
// copy of the input tensor; the kernel only accumulates `updates` into it.
template <typename T>
void CalScatterAdd(const int &inner_size, const int &indices_size, const bool &use_locking, const int *indices,
                   const T *updates, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_SCATTER_ADD_IMPL_CUH_

@ -26,20 +26,7 @@ namespace kernel {
template <typename T, typename S>
class SparseFtrlGpuKernel : public GpuKernel {
public:
// Delegate all member initialization to ResetResource() so construction and
// reuse-after-reset share a single code path (replaces the long init list).
SparseFtrlGpuKernel() { ResetResource(); }
~SparseFtrlGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@ -110,6 +97,24 @@ class SparseFtrlGpuKernel : public GpuKernel {
return true;
}
void ResetResource() noexcept override {
  // Restore every cached size, hyper-parameter, and flag to its
  // construction-time default so the kernel object can be re-initialized.
  variable_size_ = accumulation_size_ = linear_size_ = gradient_size_ = indices_size_ = 0;
  lr_ = l1_ = l2_ = lr_power_ = 0.0f;
  use_locking_ = false;
  num_index_ = 0;
  n_stride_ = 1;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);

Loading…
Cancel
Save