@@ -23,33 +23,33 @@ class AdagradOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContextBase *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("param"),
-                   "Input(param) of AdagradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("grad"),
-                   "Input(grad) of AdagradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("moment"),
-                   "Input(moment) of AdagradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
-                   "Input(learning_rate) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Param"),
+                   "Input(Param) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Grad"),
+                   "Input(Grad) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Moment"),
+                   "Input(Moment) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of AdagradOp should not be null.");
 
-    PADDLE_ENFORCE(ctx->HasOutput("param_out"),
-                   "Output(param_out) of AdagradOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("moment_out"),
-                   "Output(moment_out) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
+                   "Output(ParamOut) of AdagradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
+                   "Output(MomentOut) of AdagradOp should not be null.");
 
-    auto lr_dims = ctx->GetInputDim("learning_rate");
+    auto lr_dims = ctx->GetInputDim("LearningRate");
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
-                      "learning_rate should have one element");
-    auto param_dim = ctx->GetInputDim("param");
+                      "LearningRate should have one element");
+    auto param_dims = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("grad"),
-        "Param and grad input of AdagradOp should have the same dimension.");
+        param_dims, ctx->GetInputDim("Grad"),
+        "Param and Grad input of AdagradOp should have the same dimension.");
     PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("moment"),
-        "Param and moment input of AdagradOp should have the same dimension.");
+        param_dims, ctx->GetInputDim("Moment"),
+        "Param and Moment input of AdagradOp should have the same dimension.");
 
-    ctx->SetOutputDim("param_out", param_dim);
-    ctx->SetOutputDim("moment_out", param_dim);
+    ctx->SetOutputDim("ParamOut", param_dims);
+    ctx->SetOutputDim("MomentOut", param_dims);
   }
 };
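The shape checks above mirror the element-wise update the Adagrad kernel performs: Moment accumulates the squared gradient, so Param, Grad, and Moment must share one shape, while LearningRate is a one-element tensor. A minimal standalone sketch of that update over raw arrays (plain C++; the real kernel operates on framework::Tensor via Eigen, and the function below is hypothetical):

#include <cmath>
#include <cstddef>

// Hypothetical reference implementation of the per-element Adagrad step;
// parameter names mirror the op's inputs and outputs.
void AdagradStep(const float *param, const float *grad, const float *moment,
                 float learning_rate, float epsilon, std::size_t n,
                 float *param_out, float *moment_out) {
  for (std::size_t i = 0; i < n; ++i) {
    // MomentOut = Moment + Grad * Grad  (running sum of squared gradients)
    moment_out[i] = moment[i] + grad[i] * grad[i];
    // ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
    param_out[i] = param[i] -
                   learning_rate * grad[i] / (std::sqrt(moment_out[i]) + epsilon);
  }
}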
@@ -58,15 +58,18 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
   AdagradOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("param", "Input parameter");
-    AddInput("grad", "Input gradient");
-    AddInput("moment", "Second moment");
-    AddInput("learning_rate", "learning rate of adagrad");
+    AddInput("Param", "(Tensor) Input parameter");
+    AddInput("Grad", "(Tensor) Input gradient");
+    AddInput("Moment", "(Tensor) Second moment");
+    AddInput("LearningRate", "(Tensor) Learning rate");
 
-    AddOutput("param_out", "Output parameter");
-    AddOutput("moment_out", "Output second moment");
+    AddOutput("ParamOut", "(Tensor) Output parameter");
+    AddOutput("MomentOut", "(Tensor) Output second moment");
 
-    AddAttr<float>("epsilon", "Constant for numerical stability");
+    AddAttr<float>("epsilon",
+                   "(float, default 1.0e-6) "
+                   "Constant for numerical stability")
+        .SetDefault(1.0e-6f);
     AddComment(R"DOC(
 
 Adaptive Gradient Algorithm (Adagrad).
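For background, Adagrad is the adaptive-learning-rate method of Duchi, Hazan & Singer, "Adaptive Subgradient Methods for Online Learning and Stochastic Optimization", JMLR 12 (2011); the epsilon attribute (default 1.0e-6 per the hunk above) keeps the division well-defined on the first step, when the accumulated Moment can still be zero.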