|
|
|
@ -23,57 +23,56 @@ class AdamOp : public framework::OperatorWithKernel {
|
|
|
|
|
// Inherit every constructor of the kernel-based operator base class;
// AdamOp adds no construction logic of its own.
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
// Shape inference for the Adam update op.
//
// Validates that every input/output the Adam kernel needs is wired up,
// that the scalar hyper-parameter tensors (learning rate and the two
// beta power accumulators) each hold exactly one element, and that the
// moment tensors (and, for dense gradients, the gradient itself) match
// the parameter's shape. Finally propagates the parameter's shape to
// all three outputs.
//
// NOTE(review): the previous revision carried a full commented-out
// duplicate of every check below; the dead copies have been removed.
void InferShape(framework::InferShapeContext *ctx) const override {
  // Required inputs.
  PADDLE_ENFORCE(ctx->HasInput("Param"),
                 "Input(Param) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Grad"),
                 "Input(Grad) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Moment1"),
                 "Input(Moment1) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Moment2"),
                 "Input(Moment2) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                 "Input(LearningRate) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
                 "Input(Beta1Pow) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
                 "Input(Beta2Pow) of AdamOp should not be null.");

  // Required outputs.
  PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                 "Output(ParamOut) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
                 "Output(Moment1Out) of AdamOp should not be null.");
  PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
                 "Output(Moment2Out) of AdamOp should not be null.");

  // The learning rate and both beta power accumulators are scalars:
  // the product of their dims must be exactly 1.
  auto lr_dims = ctx->GetInputDim("LearningRate");
  PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                    "Learning rate should have 1 dimension");
  auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
  PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
                    "Beta1 power accumulator should have 1 dimension");
  auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
  PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
                    "Beta2 power accumulator should have 1 dimension");

  auto param_dims = ctx->GetInputDim("Param");
  // Only dense (LoD tensor) gradients must match the parameter's shape;
  // a sparse gradient variable is deliberately exempt from this check.
  if (ctx->GetInputsVarType("Grad")[0] ==
      framework::proto::VarType::LOD_TENSOR) {
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Grad"),
        "Param and Grad input of AdamOp should have same dimension");
  }
  // Moment accumulators always track the parameter element-for-element.
  PADDLE_ENFORCE_EQ(
      param_dims, ctx->GetInputDim("Moment1"),
      "Param and Moment1 input of AdamOp should have same dimension");
  PADDLE_ENFORCE_EQ(
      param_dims, ctx->GetInputDim("Moment2"),
      "Param and Moment2 input of AdamOp should have same dimension");

  // All outputs are updated in place conceptually: same shape as Param.
  ctx->SetOutputDim("ParamOut", param_dims);
  ctx->SetOutputDim("Moment1Out", param_dims);
  ctx->SetOutputDim("Moment2Out", param_dims);
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
auto input_data_type =
|
|
|
|
|