|
|
@ -109,7 +109,7 @@ class AdamFunctor<T, GPUAdam> {
|
|
|
|
|
|
|
|
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
|
|
|
|
|
|
|
|
|
|
|
|
// Write back to global memory
|
|
|
|
// Write back to global memory
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
@ -181,7 +181,9 @@ class AdamFunctor<T, CPUAdam> {
|
|
|
|
|
|
|
|
|
|
|
|
moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_));
|
|
|
|
param_out = param -
|
|
|
|
|
|
|
|
lr * (moment1_out /
|
|
|
|
|
|
|
|
(moment2_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -249,7 +251,7 @@ class SparseAdamFunctor<T, GPUAdam> {
|
|
|
|
|
|
|
|
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
|
|
|
|
|
|
|
|
|
|
|
|
// Write back to global memory
|
|
|
|
// Write back to global memory
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
@ -328,7 +330,7 @@ class SparseAdamFunctor<T, CPUAdam> {
|
|
|
|
|
|
|
|
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
|
|
|
|
p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
|
|
|
|
|
|
|
|
|
|
|
|
// Write back to global memory
|
|
|
|
// Write back to global memory
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
|
moment1_out_[i] = mom1;
|
|
|
|