|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/framework/eigen.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -30,8 +31,10 @@ public:
|
|
|
|
|
|
|
|
|
|
param_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
param_out->flat<T>().device(*(ctx.GetEigenDevice<Place>())) =
|
|
|
|
|
param.flat<T>() - lr * grad.flat<T>();
|
|
|
|
|
framework::EigenVector<T>::Flatten(*param_out)
|
|
|
|
|
.device(*(ctx.GetEigenDevice<Place>())) =
|
|
|
|
|
framework::EigenVector<T>::Flatten(param) -
|
|
|
|
|
lr * framework::EigenVector<T>::Flatten(grad);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|