|
|
@ -22,7 +22,7 @@ class AdamaxOp : public framework::OperatorWithKernel {
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
void InferShape(framework::InferShapeContextBase *ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
"Input(Param) of AdamaxOp should not be null.");
|
|
|
|
"Input(Param) of AdamaxOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|