|
|
|
@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Beta2 power accumulator should have 1 dimension");
|
|
|
|
|
|
|
|
|
|
auto param_dims = ctx->GetInputDim("Param");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dims, ctx->GetInputDim("Grad"),
|
|
|
|
|
"Param and Grad input of AdamOp should have same dimension");
|
|
|
|
|
if (ctx->GetInputsVarType("Grad")[0] ==
|
|
|
|
|
framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dims, ctx->GetInputDim("Grad"),
|
|
|
|
|
"Param and Grad input of AdamOp should have same dimension");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dims, ctx->GetInputDim("Moment1"),
|
|
|
|
|
"Param and Moment1 input of AdamOp should have same dimension");
|
|
|
|
|