|
|
@ -23,19 +23,22 @@ class SGDOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("param"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
"Input(param) of SGDOp should not be null.");
|
|
|
|
"Input(Param) of SGDOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("grad"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
"Input(grad) of SGDOp should not be null.");
|
|
|
|
"Input(Grad) of SGDOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("learning_rate"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
|
|
|
|
"Input(learning_rate) of SGDOp should not be null.");
|
|
|
|
"Input(LearningRate) of SGDOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("param_out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
|
|
|
|
"Output(param_out) of SGDOp should not be null.");
|
|
|
|
"Output(ParamOut) of SGDOp should not be null.");
|
|
|
|
|
|
|
|
|
|
|
|
auto param_dim = ctx->GetInputDim("param");
|
|
|
|
auto lr_dims = ctx->GetInputDim("LearningRate");
|
|
|
|
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("grad"),
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
|
|
|
|
|
|
|
|
"Learning rate should have 1 element");
|
|
|
|
|
|
|
|
auto param_dim = ctx->GetInputDim("Param");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
|
|
|
|
"Two input of SGD Op's dimension must be same.");
|
|
|
|
"Two input of SGD Op's dimension must be same.");
|
|
|
|
ctx->SetOutputDim("param_out", param_dim);
|
|
|
|
ctx->SetOutputDim("ParamOut", param_dim);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -43,10 +46,10 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("param", "input parameter");
|
|
|
|
AddInput("Param", "Input parameter");
|
|
|
|
AddInput("learning_rate", "learning rate of sgd");
|
|
|
|
AddInput("LearningRate", "Learning rate of SGD");
|
|
|
|
AddInput("grad", "input gradient");
|
|
|
|
AddInput("Grad", "Input gradient");
|
|
|
|
AddOutput("param_out", "output parameter");
|
|
|
|
AddOutput("ParamOut", "output parameter");
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
|
|
Simplest sgd algorithm.
|
|
|
|
Simplest sgd algorithm.
|
|
|
|