diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc
index 2c6ffd618a..efa0b59992 100644
--- a/paddle/operators/momentum_op.cc
+++ b/paddle/operators/momentum_op.cc
@@ -57,25 +57,30 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
   MomentumOpMaker(framework::OpProto *proto,
                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Param", "Input parameter");
-    AddInput("Grad", "Input gradient");
-    AddInput("Velocity", "Input velocity");
-    AddInput("LearningRate", "Input learning rate");
-
-    AddOutput("ParamOut", "Output parameter");
-    AddOutput("VelocityOut", "Output velocity");
-
-    AddAttr<float>("mu", "Momentum coefficient");
+    AddInput("Param",
+             "(Tensor, default Tensor<float>) "
+             "Input parameter that has to be updated");
+    AddInput("Grad",
+             "(Tensor, default Tensor<float>) "
+             "Input gradient of the parameter");
+    AddInput("Velocity",
+             "(Tensor, default Tensor<float>) "
+             "Input velocity (corresponding to the parameter) "
+             "that has to be updated");
+    AddInput("LearningRate",
+             "(Tensor, default Tensor<float>) "
+             "Input learning rate");
+
+    AddOutput("ParamOut", "(Tensor) Output updated parameter");
+    AddOutput("VelocityOut", "(Tensor) Output updated velocity");
+
+    AddAttr<float>("mu", "(float) Momentum coefficient");
     AddComment(R"DOC(
 
 Momentum Algorithm (momentum).
 
-velocity_out = mu * velocity - learning_rate * grad
-param_out = param + velocity_out
-
-Ref: Sutskever, Ilya, et al. "On the importance of initialization
-     and momentum in deep learning." ICML 2013;
-     http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
+velocity = mu * velocity + gradient
+param = param - learning_rate * velocity
 
 )DOC");
   }
diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h
index 60ff2b7590..fa3788a8ab 100644
--- a/paddle/operators/momentum_op.h
+++ b/paddle/operators/momentum_op.h
@@ -36,16 +36,16 @@ class MomentumOpKernel : public framework::OpKernel<T> {
 
     float mu = ctx.Attr<float>("mu");
 
-    auto p = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Param"));
-    auto g = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Grad"));
-    auto v = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Velocity"));
-    float lr = ctx.Input<framework::Tensor>("LearningRate")->data<float>()[0];
+    auto param = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Param"));
+    auto grad = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Grad"));
+    auto velocity = EigenVector<T>::Flatten(*ctx.Input<framework::Tensor>("Velocity"));
+    float learning_rate = ctx.Input<framework::Tensor>("LearningRate")->data<float>()[0];
 
     auto p_out = EigenVector<T>::Flatten(*param_out);
     auto v_out = EigenVector<T>::Flatten(*velocity_out);
     auto place = ctx.GetEigenDevice<Place>();
 
-    v_out.device(place) = mu * v - lr * g;
-    p_out.device(place) = p + v_out;
+    v_out.device(place) = velocity * mu + grad;
+    p_out.device(place) = param - learning_rate * v_out;
   }
 };
diff --git a/python/paddle/v2/framework/tests/test_momentum_op.py b/python/paddle/v2/framework/tests/test_momentum_op.py
index cb455bdc9f..d3353ff6e4 100644
--- a/python/paddle/v2/framework/tests/test_momentum_op.py
+++ b/python/paddle/v2/framework/tests/test_momentum_op.py
@@ -22,8 +22,8 @@ class TestMomentumOp(OpTest):
 
         self.attrs = {'mu': mu}
 
-        velocity_out = mu * velocity - learning_rate * grad
-        param_out = param + velocity_out
+        velocity_out = mu * velocity + grad
+        param_out = param - learning_rate * velocity_out
 
         self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
 