|
|
|
@ -23,46 +23,54 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
|
"Input(Param) of DecayedAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
|
"Input(Grad) of DecayedAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Moment"),
|
|
|
|
|
"Input(Moment) of DecayedAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("LearningRate"),
|
|
|
|
|
"Input(LearningRate) of DecayedAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->GetInputsVarType("Param").front() ==
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR,
|
|
|
|
|
"The input var's type should be LoDTensor, but the received is %s",
|
|
|
|
|
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->GetInputsVarType("Grad").front() ==
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR,
|
|
|
|
|
"The input var's type should be LoDTensor, but the received is %s",
|
|
|
|
|
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
|
|
|
|
|
"Output(ParamOut) of DecayedAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
|
|
|
|
|
"Output(MomentOut) of DecayedAdagradOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param",
|
|
|
|
|
"DecayedAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "DecayedAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment",
|
|
|
|
|
"DecayedAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
|
|
|
|
|
"DecayedAdagradOp");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputsVarType("Param").front(),
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input var's type should be LoDTensor, but the received is %s",
|
|
|
|
|
ctx->Inputs("Param").front(),
|
|
|
|
|
ctx->GetInputsVarType("Param").front()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->GetInputsVarType("Grad").front(),
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input var's type should be LoDTensor, but the received is %s",
|
|
|
|
|
ctx->Inputs("Grad").front(),
|
|
|
|
|
ctx->GetInputsVarType("Grad").front()));
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut",
|
|
|
|
|
"DecayedAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
|
|
|
|
|
"DecayedAdagradOp");
|
|
|
|
|
|
|
|
|
|
auto lr_dims = ctx->GetInputDim("LearningRate");
|
|
|
|
|
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
|
|
|
|
|
"Maybe the Input variable LearningRate has not "
|
|
|
|
|
"been initialized. You may need to confirm "
|
|
|
|
|
"if you put exe.run(startup_program) "
|
|
|
|
|
"after optimizer.minimize function.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Maybe the Input variable LearningRate has not "
|
|
|
|
|
"been initialized. You may need to confirm "
|
|
|
|
|
"if you put exe.run(startup_program) "
|
|
|
|
|
"after optimizer.minimize function."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
|
|
|
|
|
"LearningRate should have one element");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"LearningRate should have one element"));
|
|
|
|
|
auto param_dims = ctx->GetInputDim("Param");
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Grad"),
|
|
|
|
|
"Param and Grad input of DecayedAdagradOp should have "
|
|
|
|
|
"the same dimension.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Moment"),
|
|
|
|
|
"Param and Moment input of DecayedAdagradOp should have "
|
|
|
|
|
"the same dimension.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dims, ctx->GetInputDim("Grad"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Param and Grad input of DecayedAdagradOp should have "
|
|
|
|
|
"the same dimension."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dims, ctx->GetInputDim("Moment"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Param and Moment input of DecayedAdagradOp should have "
|
|
|
|
|
"the same dimension."));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("ParamOut", param_dims);
|
|
|
|
|
ctx->SetOutputDim("MomentOut", param_dims);
|
|
|
|
|