|
|
|
@ -24,34 +24,42 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Param"),
|
|
|
|
|
"Input(Param) of ProximalAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Moment"),
|
|
|
|
|
"Input(Moment) of ProximalAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Grad"),
|
|
|
|
|
"Input(Grad) of ProximalAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("LearningRate"),
|
|
|
|
|
"Input(LearningRate) of ProximalAdagradOp should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
|
|
|
|
|
"Output(ParamOut) of ProximalAdagradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("MomentOut"),
|
|
|
|
|
"Output(MomentOut) of ProximalAdagradOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Param"), "Input", "Param",
|
|
|
|
|
"ProximalAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Moment"), "Input", "Moment",
|
|
|
|
|
"ProximalAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Grad"), "Input", "Grad", "ProximalAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("LearningRate"), "Input", "LearningRate",
|
|
|
|
|
"ProximalAdagradOp");
|
|
|
|
|
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("ParamOut"), "Output", "ParamOut",
|
|
|
|
|
"ProximalAdagradOp");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("MomentOut"), "Output", "MomentOut",
|
|
|
|
|
"ProximalAdagradOp");
|
|
|
|
|
|
|
|
|
|
auto param_dim = ctx->GetInputDim("Param");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dim, ctx->GetInputDim("Grad"),
|
|
|
|
|
"Param and Grad of ProximalAdagrad Op must have same dimension.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
param_dim, ctx->GetInputDim("Moment"),
|
|
|
|
|
"Param and Moment of ProximalAdagrad Op must have same dimension.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Intput(Param) should be equal to the "
|
|
|
|
|
"Input(Grad) of ProximalAdagrad Op. But received "
|
|
|
|
|
"Input(Param).dimensions=[%s], "
|
|
|
|
|
"Input(Grad).dimensions=[%s]",
|
|
|
|
|
param_dim, ctx->GetInputDim("Grad")));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Intput(Param) should be equal to the "
|
|
|
|
|
"Input(Moment) of ProximalAdagrad Op. But received "
|
|
|
|
|
"Input(Param).dimensions=[%s], "
|
|
|
|
|
"Input(Moment).dimensions=[%s]",
|
|
|
|
|
param_dim, ctx->GetInputDim("Moment")));
|
|
|
|
|
|
|
|
|
|
auto lr_dim = ctx->GetInputDim("LearningRate");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
|
|
|
|
|
"Learning Rate should be a scalar.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
framework::product(lr_dim), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Learning Rate should be a scalar. But received dimension[%s]",
|
|
|
|
|
lr_dim));
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("ParamOut", param_dim);
|
|
|
|
|
ctx->SetOutputDim("MomentOut", param_dim);
|
|
|
|
|