|
|
|
@ -19,32 +19,25 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class SGDOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto param = ctx.Input<Tensor>("Param");
|
|
|
|
|
auto grad = ctx.Input<Tensor>("Grad");
|
|
|
|
|
auto param_out = ctx.Output<Tensor>("ParamOut");
|
|
|
|
|
auto learning_rate = ctx.Input<Tensor>("LearningRate");
|
|
|
|
|
auto param = ctx.Input<framework::Tensor>("Param");
|
|
|
|
|
auto grad = ctx.Input<framework::Tensor>("Grad");
|
|
|
|
|
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
|
|
|
|
|
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
|
|
|
|
|
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto p = EigenVector<T>::Flatten(*param);
|
|
|
|
|
auto g = EigenVector<T>::Flatten(*grad);
|
|
|
|
|
auto o = EigenVector<T>::Flatten(*param_out);
|
|
|
|
|
auto lr = EigenScalar<T>::From(*learning_rate);
|
|
|
|
|
auto p = framework::EigenVector<T>::Flatten(*param);
|
|
|
|
|
auto g = framework::EigenVector<T>::Flatten(*grad);
|
|
|
|
|
auto o = framework::EigenVector<T>::Flatten(*param_out);
|
|
|
|
|
auto lr = framework::EigenVector<T>::From(*learning_rate);
|
|
|
|
|
auto place = ctx.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
o.device(place) = p - lr * g;
|
|
|
|
|
Eigen::DSizes<int, 2> grad_dsize(grad->dims()[0], grad->dims()[1]);
|
|
|
|
|
o.device(place) = p - lr.broadcast(grad_dsize) * g;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|