|
|
|
@ -124,19 +124,20 @@ struct SparseAdamFunctor {
|
|
|
|
|
row_numel_(row_numel) {}
|
|
|
|
|
|
|
|
|
|
inline HOSTDEVICE void operator()(size_t i) const {
|
|
|
|
|
T beta1_pow = *beta1_pow_;
|
|
|
|
|
T beta2_pow = *beta2_pow_;
|
|
|
|
|
for (int64_t j = 0; j < row_numel_; ++j) {
|
|
|
|
|
T g = grad_[i * row_numel_ + j];
|
|
|
|
|
T mom1 = moment1_[rows_[i] * row_numel_ + j];
|
|
|
|
|
T mom2 = moment2_[rows_[i] * row_numel_ + j];
|
|
|
|
|
T lr = *lr_;
|
|
|
|
|
T beta1_pow = *beta1_pow_;
|
|
|
|
|
T beta2_pow = *beta2_pow_;
|
|
|
|
|
T p = param_[rows_[i] * row_numel_ + j];
|
|
|
|
|
|
|
|
|
|
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
|
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
|
|
|
|
|
|
|
|
|
|
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
|
|
|
|
|
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
|
|
|
|
|
param_out_[rows_[i] * row_numel_ + j] = p;
|
|
|
|
|