|
|
@ -29,12 +29,17 @@ class AdagradOp : public framework::OperatorWithKernel {
|
|
|
|
"Input(grad) of AdagradOp should not be null.");
|
|
|
|
"Input(grad) of AdagradOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("moment"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("moment"),
|
|
|
|
"Input(moment) of AdagradOp should not be null.");
|
|
|
|
"Input(moment) of AdagradOp should not be null.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
|
|
|
|
|
|
|
|
"Input(learning_rate) of AdagradOp should not be null.");
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
|
|
|
|
"Output(param_out) of AdagradOp should not be null.");
|
|
|
|
"Output(param_out) of AdagradOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("moment_out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("moment_out"),
|
|
|
|
"Output(moment_out) of AdagradOp should not be null.");
|
|
|
|
"Output(moment_out) of AdagradOp should not be null.");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto lr_dims = ctx->GetInputDim("learning_rate");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
|
|
|
|
|
|
|
|
"learning_rate should have one element");
|
|
|
|
auto param_dim = ctx->GetInputDim("param");
|
|
|
|
auto param_dim = ctx->GetInputDim("param");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
param_dim, ctx->GetInputDim("grad"),
|
|
|
|
param_dim, ctx->GetInputDim("grad"),
|
|
|
@ -56,11 +61,11 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
AddInput("param", "Input parameter");
|
|
|
|
AddInput("param", "Input parameter");
|
|
|
|
AddInput("grad", "Input gradient");
|
|
|
|
AddInput("grad", "Input gradient");
|
|
|
|
AddInput("moment", "Second moment");
|
|
|
|
AddInput("moment", "Second moment");
|
|
|
|
|
|
|
|
AddInput("learning_rate", "learning rate of adagrad");
|
|
|
|
|
|
|
|
|
|
|
|
AddOutput("param_out", "Output parameter");
|
|
|
|
AddOutput("param_out", "Output parameter");
|
|
|
|
AddOutput("moment_out", "Output second moment");
|
|
|
|
AddOutput("moment_out", "Output second moment");
|
|
|
|
|
|
|
|
|
|
|
|
AddAttr<float>("learning_rate", "Learning rate");
|
|
|
|
|
|
|
|
AddAttr<float>("epsilon", "Constant for numerical stability");
|
|
|
|
AddAttr<float>("epsilon", "Constant for numerical stability");
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
|
|