Refine the Eigen usage for CPU implementation.

release/0.11.0
dangqingqing 8 years ago
parent 5bd1e73f5e
commit e03b574e0e

@ -44,15 +44,11 @@ class MomentumOpKernel : public framework::OpKernel<T> {
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
v_out.device(place) = v * mu + g;
v_out = v * mu + g;
if (use_nesterov) {
p_out.device(place) = p - (g - v_out * mu) * lr[0];
p_out = p - (g - v_out * mu) * lr[0];
} else {
p_out.device(place) = p - lr[0] * v_out;
p_out = p - lr[0] * v_out;
}
}
};

Loading…
Cancel
Save