optimize the dsize

revert-4814-Add_sequence_project_op
qiaolongfei 7 years ago
parent 775c60246b
commit 8ebc31d935

@ -33,10 +33,10 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::From(*learning_rate);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 2> grad_dsize(grad->dims()[0], grad->dims()[1]);
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
}
};

Loading…
Cancel
Save