|
|
|
@ -23,22 +23,27 @@ class LargeScaleFuseAdamOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
|
"Input(Grad) of LargeScaleFuseAdamOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Grad"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Grad) of LargeScaleFuseAdamOp should not be null."));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("LearningRate"),
|
|
|
|
|
"Input(LearningRate) of LargeScaleFuseAdamOp should not be null.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(LearningRate) of LargeScaleFuseAdamOp should not be null."));
|
|
|
|
|
|
|
|
|
|
auto lr_dims = ctx->GetInputDim("LearningRate");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
|
|
|
|
|
"Maybe the Input variable LearningRate has not "
|
|
|
|
|
"been initialized. You may need to confirm "
|
|
|
|
|
"if you put exe.run(startup_program) "
|
|
|
|
|
"after optimizer.minimize function.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Maybe the Input variable LearningRate has not "
|
|
|
|
|
"been initialized. You may need to confirm "
|
|
|
|
|
"if you put exe.run(startup_program) "
|
|
|
|
|
"after optimizer.minimize function."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
|
|
|
|
|
"Learning rate should have 1 element");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Learning rate should have 1 element"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|