|
|
|
@ -28,18 +28,21 @@ class RankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
// input check
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null");
|
|
|
|
|
auto p_dims = ctx.Input<framework::Tensor>("P")->dims();
|
|
|
|
|
auto oi_dims = ctx.Input<framework::Tensor>("Oi")->dims();
|
|
|
|
|
auto oj_dims = ctx.Input<framework::Tensor>("Oj")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(oi_dims, oj_dims,
|
|
|
|
|
"Input(Oi) and Input(Oj) must have the same size");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
p_dims, oi_dims,
|
|
|
|
|
"Input(P) must have the same size with Input(Oi) & Input(Oj)");
|
|
|
|
|
ctx.Output<framework::Tensor>("Out")->Resize(p_dims);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
|
|
|
|
|
"Input(Label) shouldn't be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
|
|
|
|
|
"Input(Left) shouldn't be null");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
|
|
|
|
|
"Input(Right) shouldn't be null");
|
|
|
|
|
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
|
|
|
|
|
auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
|
|
|
|
|
auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
|
|
|
|
|
PADDLE_ENFORCE((label_dims.size() == 1) && (left_dims.size() == 1) &&
|
|
|
|
|
(right_dims.size() == 1),
|
|
|
|
|
"The rank of all inputs must be 1.");
|
|
|
|
|
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
|
|
|
|
|
"All inputs must have the same size");
|
|
|
|
|
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -48,14 +51,23 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
RankLossOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("P", "The desired target values for posteriors.");
|
|
|
|
|
AddInput("Oi", "The model output for item i.");
|
|
|
|
|
AddInput("Oj", "The model output for item j.");
|
|
|
|
|
AddOutput("Out", "The output tensor of RankLoss operator.");
|
|
|
|
|
AddInput("Label",
|
|
|
|
|
"The label indicating A ranked higher than B or not, 1-D tensor.");
|
|
|
|
|
AddInput("Left", "The output of RankNet for doc A, 1-D tensor.");
|
|
|
|
|
AddInput("Right", "The output of RankNet for doc B, 1-D tensor");
|
|
|
|
|
AddOutput("Out", "The output loss of RankLoss operator, 1-D tensor.");
|
|
|
|
|
AddComment(R"DOC(RankLoss operator
|
|
|
|
|
|
|
|
|
|
A rank loss operator for learning to rank (LTR) task. This operator contains
|
|
|
|
|
three inputs: P, Oi, and Oj, and the rank cost can be expressed as
|
|
|
|
|
Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with
|
|
|
|
|
one training sample consisting of a pair of doc A and B, and the label P
|
|
|
|
|
indicating that A is ranked higher than B or not:
|
|
|
|
|
|
|
|
|
|
P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
|
|
|
|
|
the input pair.
|
|
|
|
|
|
|
|
|
|
The RankLoss operator contains three inputs: Left (o_i), Right (o_j) and Label
|
|
|
|
|
(P_{i,j}), which represent the output of RankNet for two docs and the label
|
|
|
|
|
respectively, and yields the rank loss C_{i,j} by following the expression
|
|
|
|
|
|
|
|
|
|
\f[
|
|
|
|
|
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
|
|
|
|
@ -63,10 +75,11 @@ three inputs: P, Oi, and Oj, and the rank cost can be expressed as
|
|
|
|
|
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
|
|
|
|
|
\f]
|
|
|
|
|
|
|
|
|
|
A detailed explanation about these notations can be found in
|
|
|
|
|
The operator can take inputs of one sample or in batch.
|
|
|
|
|
|
|
|
|
|
[1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to
|
|
|
|
|
Rank useing Gradient Descent.
|
|
|
|
|
Rank using Gradient Descent.
|
|
|
|
|
http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -81,15 +94,25 @@ class RankLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("P"), "Input(P) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oi"), "Input(Oi) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Oj"), "Input(Oj) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
|
|
|
|
|
"Input(Label) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
|
|
|
|
|
"Input(Left) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
|
|
|
|
|
"Input(Right) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
|
auto dims = ctx.Input<framework::Tensor>("P")->dims();
|
|
|
|
|
ctx.Output<framework::Tensor>(framework::GradVarName("P"))->Resize(dims);
|
|
|
|
|
ctx.Output<framework::Tensor>(framework::GradVarName("Oi"))->Resize(dims);
|
|
|
|
|
ctx.Output<framework::Tensor>(framework::GradVarName("Oj"))->Resize(dims);
|
|
|
|
|
auto dims = ctx.Input<framework::Tensor>("Left")->dims();
|
|
|
|
|
auto *left_grad =
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("Left"));
|
|
|
|
|
auto *right_grad =
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("Right"));
|
|
|
|
|
if (left_grad) {
|
|
|
|
|
left_grad->Resize(dims);
|
|
|
|
|
}
|
|
|
|
|
if (right_grad) {
|
|
|
|
|
right_grad->Resize(dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|