@@ -33,19 +33,35 @@ class SparseFtrlGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
-              void *stream_ptr) override {
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *variable = GetDeviceAddress<T>(inputs, 0);
     T *accumulation = GetDeviceAddress<T>(inputs, 1);
     T *linear = GetDeviceAddress<T>(inputs, 2);
     T *gradient = GetDeviceAddress<T>(inputs, 3);
     S *indices = GetDeviceAddress<S>(inputs, 4);
+    T *variable_out = GetDeviceAddress<T>(outputs, 0);
+    T *accumulation_out = GetDeviceAddress<T>(outputs, 1);
+    T *linear_out = GetDeviceAddress<T>(outputs, 2);
     CalSparseApplyFtrl(gradient, indices, num_index_, n_stride_, lr_, l1_, l2_, lr_power_, use_locking_, variable,
                        accumulation, linear, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&variable_out[0], &variable[0], variable_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&accumulation_out[0], &accumulation[0], accumulation_size_,
+                                               cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(&linear_out[0], &linear[0], linear_size_, cudaMemcpyDeviceToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
     kernel_node_ = kernel_node;
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 5) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but sparse ftrl needs 5 inputs.";
@@ -104,9 +120,9 @@ class SparseFtrlGpuKernel : public GpuKernel {
     input_size_list_.push_back(linear_size_);
     input_size_list_.push_back(gradient_size_);
     input_size_list_.push_back(indices_size_);
-    output_size_list_.push_back(0);
-    output_size_list_.push_back(0);
-    output_size_list_.push_back(0);
+    output_size_list_.push_back(variable_size_);
+    output_size_list_.push_back(accumulation_size_);
+    output_size_list_.push_back(linear_size_);
   }
 
   void ResetResource() noexcept override {
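
Note on the Launch change above: CalSparseApplyFtrl still applies the Ftrl update in place to variable, accumulation, and linear, and the new code then mirrors those buffers into the kernel's output addresses with device-to-device cudaMemcpyAsync calls queued on the same stream, so the copies are ordered after the update without extra synchronization. The standalone sketch below only illustrates that update-in-place-then-copy pattern under those assumptions; the kernel name InPlaceUpdate, the buffer names, and the sizes are hypothetical placeholders, not MindSpore code.

// Standalone sketch (not MindSpore code): update a buffer in place on a stream,
// then publish the result to a separate output buffer with an async D2D copy.
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for CalSparseApplyFtrl: writes the update into param itself.
__global__ void InPlaceUpdate(float *param, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    param[i] -= 0.1f * param[i];
  }
}

int main() {
  const int n = 1024;
  const size_t bytes = n * sizeof(float);
  float *param = nullptr;
  float *param_out = nullptr;
  cudaMalloc(&param, bytes);
  cudaMalloc(&param_out, bytes);
  cudaMemset(param, 0, bytes);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // 1) In-place update queued on the stream (the role CalSparseApplyFtrl plays above).
  InPlaceUpdate<<<(n + 255) / 256, 256, 0, stream>>>(param, n);
  // 2) Device-to-device copy on the same stream; stream ordering guarantees it sees
  //    the updated values, mirroring the variable -> variable_out copy in the diff.
  cudaMemcpyAsync(param_out, param, bytes, cudaMemcpyDeviceToDevice, stream);

  cudaStreamSynchronize(stream);
  printf("update and output copy completed\n");

  cudaFree(param);
  cudaFree(param_out);
  cudaStreamDestroy(stream);
  return 0;
}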