@@ -23,57 +23,61 @@ class AdamaxOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Param"),
-                   "Input(Param) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Grad"),
-                   "Input(Grad) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Moment"),
-                   "Input(Moment) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("InfNorm"),
-                   "Input(InfNorm) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-                   "Input(LearningRate) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
-                   "Input(Beta1Pow) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(
-        ctx->GetInputsVarType("Param").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
-    PADDLE_ENFORCE(
-        ctx->GetInputsVarType("Grad").front() ==
-            framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
-
-    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
-                   "Output(ParamOut) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
-                   "Output(MomentOut) of AdamaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("InfNormOut"),
-                   "Output(InfNormOut) of AdamaxOp should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param", "Adamax");
+    OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "Adamax");
+    OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment", "Adamax");
+    OP_INOUT_CHECK(ctx->HasInput("InfNorm"), "Input", "InfNorm", "Adamax");
+    OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
+                   "Adamax");
+    OP_INOUT_CHECK(ctx->HasInput("Beta1Pow"), "Input", "Beta1Pow", "Adamax");
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputsVarType("Param").front(),
+        framework::proto::VarType::LOD_TENSOR,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->Inputs("Param").front(),
+            ctx->GetInputsVarType("Param").front()));
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputsVarType("Grad").front(),
+        framework::proto::VarType::LOD_TENSOR,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->Inputs("Grad").front(),
+            ctx->GetInputsVarType("Grad").front()));
+
+    OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut", "Adamax");
+    OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
+                   "Adamax");
+    OP_INOUT_CHECK(ctx->HasOutput("InfNormOut"), "Output", "InfNormOut",
+                   "Adamax");
 
     auto lr_dims = ctx->GetInputDim("LearningRate");
     PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
-                      "Maybe the Input variable LearningRate has not "
-                      "been initialized. You may need to confirm "
-                      "if you put exe.run(startup_program) "
-                      "after optimizer.minimize function.");
+                      platform::errors::InvalidArgument(
+                          "Maybe the Input variable LearningRate has not "
+                          "been initialized. You may need to confirm "
+                          "if you put exe.run(startup_program) "
+                          "after optimizer.minimize function."));
     PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
-                      "Learning rate should have 1 dimension");
+                      platform::errors::InvalidArgument(
+                          "Learning rate should have 1 dimension"));
     auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
     PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
-                      "Beta1 power accumulator should have 1 dimension");
+                      platform::errors::InvalidArgument(
+                          "Beta1 power accumulator should have 1 dimension"));
     auto param_dims = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
         param_dims, ctx->GetInputDim("Grad"),
-        "Param and Grad input of AdamaxOp should have same dimension");
+        platform::errors::InvalidArgument(
+            "Param and Grad input of AdamaxOp should have same dimension"));
     PADDLE_ENFORCE_EQ(
         param_dims, ctx->GetInputDim("Moment"),
-        "Param and Moment input of AdamaxOp should have same dimension");
+        platform::errors::InvalidArgument(
+            "Param and Moment input of AdamaxOp should have same dimension"));
     PADDLE_ENFORCE_EQ(
         param_dims, ctx->GetInputDim("InfNorm"),
-        "Param and InfNorm input of AdamaxOp should have same dimension");
+        platform::errors::InvalidArgument(
+            "Param and InfNorm input of AdamaxOp should have same dimension"));
 
     ctx->SetOutputDim("ParamOut", param_dims);
     ctx->SetOutputDim("MomentOut", param_dims);
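Taken as a whole, the hunk swaps bare-message PADDLE_ENFORCE(cond, "msg") calls for the OP_INOUT_CHECK macro on input/output presence checks, and wraps the remaining shape/type enforcements in platform::errors::InvalidArgument, so each failure raises a typed error with a uniformly formatted message instead of a raw format string. The dependency-free sketch below mirrors that call shape only; InOutCheck and EnforceEq are hypothetical stand-ins invented here, not the actual Paddle macros or their expansions.

#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for PADDLE_ENFORCE_EQ(actual, expected, error): compare and,
// on mismatch, throw a typed exception carrying the caller's message.
template <typename T, typename U>
void EnforceEq(const T &actual, const U &expected, const std::string &msg) {
  if (!(actual == expected)) throw std::invalid_argument(msg);
}

// Stand-in for OP_INOUT_CHECK(expr, "Input", "Param", "Adamax"): the slot
// kind, slot name, and op name are passed once and formatted uniformly,
// rather than hand-writing "Input(Param) of AdamaxOp should not be null."
// at every call site as the removed lines did.
void InOutCheck(bool present, const std::string &kind, const std::string &name,
                const std::string &op) {
  std::ostringstream os;
  os << kind << "(" << name << ") of " << op << "Op should not be null.";
  EnforceEq(present, true, os.str());
}

int main() {
  InOutCheck(true, "Input", "Param", "Adamax");              // passes
  EnforceEq(1, 1, "Learning rate should have 1 dimension");  // passes
  return 0;
}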