|
|
|
@ -27,6 +27,8 @@ class SGDOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(param) of SGDOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("grad"),
|
|
|
|
|
"Input(grad) of SGDOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
|
|
|
|
|
"Input(learning_rate) of SGDOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
|
|
|
|
|
"Output(param_out) of SGDOp should not be null.");
|
|
|
|
|
|
|
|
|
@ -42,9 +44,9 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("param", "input parameter");
|
|
|
|
|
AddInput("learning_rate", "learning rate of sgd");
|
|
|
|
|
AddInput("grad", "input gradient");
|
|
|
|
|
AddOutput("param_out", "output parameter");
|
|
|
|
|
AddAttr<float>("learning_rate", "learning rate of sgd");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
Simplest sgd algorithm.
|
|
|
|
|