|
|
|
@ -33,10 +33,10 @@ class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto p = framework::EigenVector<T>::Flatten(*param);
|
|
|
|
|
auto g = framework::EigenVector<T>::Flatten(*grad);
|
|
|
|
|
auto o = framework::EigenVector<T>::Flatten(*param_out);
|
|
|
|
|
auto lr = framework::EigenVector<T>::From(*learning_rate);
|
|
|
|
|
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
|
|
|
|
|
auto place = ctx.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
Eigen::DSizes<int, 2> grad_dsize(grad->dims()[0], grad->dims()[1]);
|
|
|
|
|
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
|
|
|
|
|
o.device(place) = p - lr.broadcast(grad_dsize) * g;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|