|
|
|
@ -42,12 +42,18 @@ class SGDGpuKernel : public GpuKernel {
|
|
|
|
|
T *accum = GetDeviceAddress<T>(inputs, 3);
|
|
|
|
|
T *momentum = GetDeviceAddress<T>(inputs, 4);
|
|
|
|
|
T *stat = GetDeviceAddress<T>(inputs, 5);
|
|
|
|
|
T *output_param = GetDeviceAddress<T>(outputs, 0);
|
|
|
|
|
|
|
|
|
|
SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream));
|
|
|
|
|
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
|
|
|
|
cudaMemcpyAsync(output_param, param, sizeof(T) * size_, cudaMemcpyDeviceToDevice,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream)),
|
|
|
|
|
"SGD cudaMemcpyAsync params to outputs failed");
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override {
|
|
|
|
|
kernel_node_ = kernel_node;
|
|
|
|
|
dampening_ = GetAttr<float>(kernel_node, "dampening");
|
|
|
|
|
weight_decay_ = GetAttr<float>(kernel_node, "weight_decay");
|
|
|
|
|
nesterov_ = GetAttr<bool>(kernel_node, "nesterov");
|
|
|
|
|