|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
@ -21,25 +22,31 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
struct SparseAdagradFunctor {
|
|
|
|
|
void operator()(const DeviceContext& context,
|
|
|
|
|
const framework::SelectedRows& grad,
|
|
|
|
|
const framework::Tensor& learning_rate, T epsilon,
|
|
|
|
|
framework::Tensor* moment, framework::Tensor* param);
|
|
|
|
|
void operator()(const DeviceContext &context,
|
|
|
|
|
const framework::SelectedRows &grad,
|
|
|
|
|
const framework::Tensor &learning_rate, T epsilon,
|
|
|
|
|
framework::Tensor *moment, framework::Tensor *param);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class AdagradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
|
|
|
|
|
auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
const auto *param_var = ctx.InputVar("Param");
|
|
|
|
|
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
|
|
|
|
|
"The Var(%s)'s type should be LoDTensor, "
|
|
|
|
|
"but the received is %s",
|
|
|
|
|
ctx.Inputs("Param").front(), param_var->Type().name());
|
|
|
|
|
|
|
|
|
|
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
|
|
|
|
|
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
|
|
|
|
|
|
|
|
|
|
param_out_tensor->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
|
|
|
|
|
|
|
|
|
|
auto* grad_var = ctx.InputVar("Grad");
|
|
|
|
|
auto *grad_var = ctx.InputVar("Grad");
|
|
|
|
|
if (grad_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
auto param = framework::EigenVector<T>::Flatten(
|
|
|
|
|
*ctx.Input<framework::Tensor>("Param"));
|
|
|
|
@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
*ctx.Input<framework::Tensor>("Grad"));
|
|
|
|
|
auto moment = framework::EigenVector<T>::Flatten(
|
|
|
|
|
*ctx.Input<framework::Tensor>("Moment"));
|
|
|
|
|
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
|
|
|
|
|
auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
|
|
|
|
|
|
|
|
|
|
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
|
|
|
|
|
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
|
|
|
|
|
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
auto *place = ctx.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
|
|
|
|
|
moment_out.device(*place) = moment + grad * grad;
|
|
|
|
|
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
|
|
|
|
|
if (platform::is_cpu_place(ctx.GetPlace())) {
|
|
|
|
|
auto* lr = learning_rate->data<T>();
|
|
|
|
|
auto *lr = learning_rate->data<T>();
|
|
|
|
|
param_out.device(*place) =
|
|
|
|
|
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
|
|
|
|
|
} else {
|
|
|
|
@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
|
|
|
|
|
}
|
|
|
|
|
} else if (grad_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto* param_tensor = ctx.Input<framework::Tensor>("Param");
|
|
|
|
|
auto *param_tensor = ctx.Input<framework::Tensor>("Param");
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
|
|
|
|
|
|
|
|
|
|
auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
|
|
|
|
|
auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
|
|
|
|
|
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
|
|
|
|
|
|
|
|
|
|
SparseAdagradFunctor<DeviceContext, T> functor;
|
|
|
|
|