From f9ef6d15190604e4a0780e76c7254a7875e7352e Mon Sep 17 00:00:00 2001
From: peterzhang2029
Date: Thu, 4 Jan 2018 16:11:50 +0800
Subject: [PATCH] init

---
 paddle/operators/adagrad_op.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/operators/adagrad_op.h b/paddle/operators/adagrad_op.h
index 0d77dbcbac..667c1939a2 100644
--- a/paddle/operators/adagrad_op.h
+++ b/paddle/operators/adagrad_op.h
@@ -47,8 +47,8 @@ class AdagradOpKernel : public framework::OpKernel<T> {
           *ctx.Input<framework::Tensor>("Grad"));
       auto moment = framework::EigenVector<T>::Flatten(
           *ctx.Input<framework::Tensor>("Moment"));
-      auto lr = framework::EigenVector<T>::Flatten(
-          *ctx.Input<framework::Tensor>("LearningRate"));
+      auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+      auto* lr = learning_rate->data<T>();
 
       auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
       auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
@@ -57,7 +57,7 @@ class AdagradOpKernel : public framework::OpKernel<T> {
       moment_out.device(*place) = moment + grad * grad;
       Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
       param_out.device(*place) =
-          param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+          param - lr[0] * grad / (moment_out.sqrt() + epsilon);
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto* param_tensor = ctx.Input<framework::Tensor>("Param");
       PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
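
Note (not part of the patch): a minimal standalone sketch of the dense Adagrad update this kernel performs, written with plain arrays rather than Paddle's Eigen wrappers. It illustrates why reading the single-element LearningRate tensor through a raw pointer and using lr[0] as a scalar is equivalent to the previous lr.broadcast(m_dsize) over the flattened parameter. All names below (AdagradStep, scalar_lr, n) are illustrative, not Paddle APIs.

    #include <cmath>
    #include <cstddef>

    // Dense Adagrad step over n elements:
    //   moment += grad^2
    //   param  -= lr * grad / (sqrt(moment) + epsilon)
    void AdagradStep(float* param, float* moment, const float* grad,
                     const float* lr,  // points at a 1-element LearningRate tensor
                     float epsilon, std::size_t n) {
      // One scalar read replaces broadcasting the learning rate
      // across all n elements, as in the old lr.broadcast(m_dsize).
      const float scalar_lr = lr[0];
      for (std::size_t i = 0; i < n; ++i) {
        moment[i] += grad[i] * grad[i];
        param[i] -= scalar_lr * grad[i] / (std::sqrt(moment[i]) + epsilon);
      }
    }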