|
|
|
@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
|
"Input(param) of RmspropOp should not be null.");
|
|
|
|
|
"Input(Param) of RmspropOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
|
"Input(grad) of RmspropOp should not be null.");
|
|
|
|
|
"Input(Grad) of RmspropOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Moment"),
|
|
|
|
|
"Input(moment) of RmspropOp should not be null.");
|
|
|
|
|
"Input(Moment) of RmspropOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
|
|
|
|
|
"Input(LearningRate) of RmspropOp should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
|
|
|
|
|
"Output(param_out) of RmspropOp should not be null.");
|
|
|
|
@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel {
|
|
|
|
|
param_dim, ctx->GetInputDim("Moment"),
|
|
|
|
|
"Param and moment input of RmspropOp should have the same dimension.");
|
|
|
|
|
|
|
|
|
|
auto lr_dim = ctx->GetInputDim("LearningRate");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
|
|
|
|
|
"Learning Rate should be a scalar.");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("ParamOut", param_dim);
|
|
|
|
|
ctx->SetOutputDim("MomentOut", param_dim);
|
|
|
|
|
}
|
|
|
|
@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("Param", "Input parameter");
|
|
|
|
|
AddInput("Grad", "Input gradient");
|
|
|
|
|
AddInput("Moment", "Second moment");
|
|
|
|
|
AddInput("LearningRate", "Learning Rate");
|
|
|
|
|
|
|
|
|
|
AddOutput("ParamOut", "Output parameter");
|
|
|
|
|
AddOutput("MomentOut", "Output second moment");
|
|
|
|
|
|
|
|
|
|
AddAttr<float>("learningRate", "Learning rate");
|
|
|
|
|
AddAttr<float>("epsilon", "Constant for numerical stability");
|
|
|
|
|
AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
RMSprop
|
|
|
|
|
|
|
|
|
|
MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad
|
|
|
|
|
ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon)
|
|
|
|
|
ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
|
|
|
|
|
|
|
|
|
|
The original slide(Slide 29 of
|
|
|
|
|
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
|
|
|
|