|
|
|
@ -23,21 +23,33 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ModifiedHuberLoss");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ModifiedHuberLoss");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The tensor rank of X must be 2.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims.size(), 2, platform::errors::InvalidArgument(
|
|
|
|
|
"Input(input) rank should be 2, "
|
|
|
|
|
"but received input rank(%d) != 2",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime() ||
|
|
|
|
|
(framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims, y_dims,
|
|
|
|
|
"The shape of X and Y must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims, y_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input(input) and Input(label) should have the same "
|
|
|
|
|
"shape, but received input shape [%s] != label shape [%s]",
|
|
|
|
|
x_dims, y_dims));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], 1, "The 2nd dimension of X must be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims[1], 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The second dimension of Input(input) should be 1, "
|
|
|
|
|
"but received second dimension of input (%d) != 1",
|
|
|
|
|
x_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("IntermediateVal", x_dims);
|
|
|
|
@ -87,11 +99,11 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),
|
|
|
|
|
"Intermediate value must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@Grad) must not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ModifiedHuberLossGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("IntermediateVal"), "Input", "IntermediateVal",
|
|
|
|
|
"ModifiedHuberLossGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "ModifiedHuberLossGrad");
|
|
|
|
|
|
|
|
|
|
auto y_dims = ctx->GetInputDim("Y");
|
|
|
|
|
auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
|
|
|
|
@ -100,9 +112,20 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
intermediate_dims, y_dims,
|
|
|
|
|
"The shape of X and intermediate value must be the same.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(out_grad_dims, y_dims,
|
|
|
|
|
"The shape of Input(Out@Grad) and X must be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Intermediate variable which will be reused in "
|
|
|
|
|
"backward processing should the same as "
|
|
|
|
|
"the shape of Input(label), but received Intermediate variable "
|
|
|
|
|
"shape [%s] != label shape [%s]",
|
|
|
|
|
intermediate_dims, y_dims));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
out_grad_dims, y_dims,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of output gradient should be the same as "
|
|
|
|
|
"the shape of Input(label), but received the output gradient "
|
|
|
|
|
"shape [%s] != label shape [%s]",
|
|
|
|
|
out_grad_dims, y_dims));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|