|
|
|
@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(Tensor, default Tensor<float>) "
|
|
|
|
|
"Input learning rate");
|
|
|
|
|
|
|
|
|
|
AddOutput("ParamOut", "(Tensor) Output updated parameter");
|
|
|
|
|
AddOutput("VelocityOut", "(Tensor) Output updated velocity");
|
|
|
|
|
AddOutput("ParamOut",
|
|
|
|
|
"(Tensor) This output is updated parameter. "
|
|
|
|
|
"It shared memory with Input(Param).");
|
|
|
|
|
AddOutput("VelocityOut",
|
|
|
|
|
"(Tensor) This output is updated velocity. "
|
|
|
|
|
"It shared memory with Input(Velocity).");
|
|
|
|
|
|
|
|
|
|
AddAttr<float>("mu", "(float) Momentum coefficient");
|
|
|
|
|
AddAttr<bool>("use_nesterov",
|
|
|
|
|