|
|
|
@ -44,15 +44,11 @@ class MomentumOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto g = framework::EigenVector<T>::Flatten(*grad);
|
|
|
|
|
auto* lr = learning_rate->data<T>();
|
|
|
|
|
|
|
|
|
|
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
|
|
|
|
|
|
|
|
|
|
v_out.device(place) = v * mu + g;
|
|
|
|
|
v_out = v * mu + g;
|
|
|
|
|
if (use_nesterov) {
|
|
|
|
|
p_out.device(place) = p - (g - v_out * mu) * lr[0];
|
|
|
|
|
p_out = p - (g - v_out * mu) * lr[0];
|
|
|
|
|
} else {
|
|
|
|
|
p_out.device(place) = p - lr[0] * v_out;
|
|
|
|
|
p_out = p - lr[0] * v_out;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|