@@ -21,24 +21,24 @@ class HuberLossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must be initialized.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must be initialized.");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");
 
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
 
-    PADDLE_ENFORCE_EQ(x->dims(), y->dims());
-    PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
+    PADDLE_ENFORCE_EQ(x_dims, y_dims);
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                       "The rank of Input(X) must be 2 and the shape is "
                       "[batch_size, 1].");
-    PADDLE_ENFORCE_EQ(x->dims()[1], 1,
+    PADDLE_ENFORCE_EQ(x_dims[1], 1,
                       "Each row of Input(X) contains a real value, "
                       "so the 2nd dimension of Input(X) must be 1.");
 
-    ctx.Output<Tensor>("Residual")->Resize(x->dims());
-    ctx.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
+    ctx->SetOutputDim("Residual", x_dims);
+    ctx->SetOutputDim("Out", {x_dims[0], 1});
+    ctx->ShareLoD("X", "Out");
   }
 };
@@ -55,7 +55,7 @@ class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
              "The target value of huber loss op."
              "Y is a 2-D tensor with shape [batch_size, 1].");
     AddOutput("Residual",
-              "Intermediate tensor to cache residual value of Y and X."
+              "Intermediate tensor to cache residual value between Y and X."
               "The shape is same as Input(X) and will be reused in backward.")
         .AsIntermediate();
     AddOutput("Out",
@@ -82,25 +82,30 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* residual = ctx.Input<Tensor>("Residual");
-    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-
-    PADDLE_ENFORCE_NOT_NULL(x, "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(y, "Input(Y) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(residual, "Input(Residual) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@GRAD) should not be null.");
-
-    PADDLE_ENFORCE_EQ(residual->dims(), x->dims());
-    PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims());
-
-    if (x_grad) x_grad->Resize(x->dims());
-    if (y_grad) y_grad->Resize(y->dims());
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Residual"),
+                   "Input(Residual) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto residual_dims = ctx->GetInputDim("Residual");
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+
+    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
+    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
+
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
   }
 };
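For context only (not part of the patch): the "Residual" output this diff describes caches r = y - x per row, and the forward kernel reduces each residual to the standard Huber loss value. The snippet below is a minimal standalone sketch of that per-element formula; the `delta` parameter stands in for the op's threshold attribute and the function name is illustrative, not the kernel's actual code.

// Reference sketch: elementwise Huber loss on a cached residual r = y - x.
// `delta` is assumed to correspond to the op's threshold attribute.
#include <cmath>

template <typename T>
T HuberLoss(T residual, T delta) {
  T abs_r = std::abs(residual);
  // Quadratic near zero, linear in the tails.
  return abs_r <= delta ? static_cast<T>(0.5) * residual * residual
                        : delta * (abs_r - static_cast<T>(0.5) * delta);
}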