|
|
@ -4,7 +4,7 @@
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel {
|
|
|
|
auto right_dims = ctx->GetInputDim("Right");
|
|
|
|
auto right_dims = ctx->GetInputDim("Right");
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
|
|
|
|
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
|
|
|
|
"All inputs must have the same size");
|
|
|
|
"All inputs must have the same size.");
|
|
|
|
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
"All inputs must be row vector with size batch_size x 1.");
|
|
|
|
(label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
|
|
|
"All inputs must be 2-D tensors with shape [batch_size x 1].");
|
|
|
|
ctx->SetOutputDim("Out", label_dims);
|
|
|
|
ctx->SetOutputDim("Out", label_dims);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("Label",
|
|
|
|
AddInput("Label",
|
|
|
|
"The label indicating A ranked higher than B or not, row vector.");
|
|
|
|
"(2-D Tensor with shape [batch_size x 1]) "
|
|
|
|
AddInput("Left", "The output of RankNet for doc A, vector.");
|
|
|
|
"The label indicating A ranked higher than B or not.");
|
|
|
|
AddInput("Right", "The output of RankNet for doc B, vetor.");
|
|
|
|
AddInput("Left",
|
|
|
|
AddOutput("Out", "The output loss of RankLoss operator, vector.");
|
|
|
|
"(2-D Tensor with shape [batch_size x 1]) "
|
|
|
|
|
|
|
|
"The output of RankNet for doc A.");
|
|
|
|
|
|
|
|
AddInput("Right",
|
|
|
|
|
|
|
|
"(2-D Tensor with shape [batch_size x 1]) "
|
|
|
|
|
|
|
|
"The output of RankNet for doc B.");
|
|
|
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
|
|
|
"(2-D Tensor with shape [batch_size x 1]) "
|
|
|
|
|
|
|
|
"The output loss of RankLoss operator.");
|
|
|
|
AddComment(R"DOC(
|
|
|
|
AddComment(R"DOC(
|
|
|
|
RankLoss Operator.
|
|
|
|
RankLoss Operator.
|
|
|
|
|
|
|
|
|
|
|
@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
|
|
|
|
the input pair.
|
|
|
|
the input pair.
|
|
|
|
|
|
|
|
|
|
|
|
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
|
|
|
|
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
|
|
|
|
(P_{i,j}), which represent the output of RankNet for the two docs and the label,
|
|
|
|
(P_{i,j}), which represent the output score of RankNet for the two docs and
|
|
|
|
respectively, and yields the rank loss C_{i,j} using the following equation:
|
|
|
|
the label respectively, and yields the rank loss C_{i,j} using the following
|
|
|
|
|
|
|
|
equation:
|
|
|
|
|
|
|
|
|
|
|
|
\f$$
|
|
|
|
$$
|
|
|
|
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
|
|
|
|
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\
|
|
|
|
o_{i,j} = o_i - o_j \\
|
|
|
|
o_{i,j} = o_i - o_j \\
|
|
|
|
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
|
|
|
|
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
|
|
|
|
\f$$
|
|
|
|
$$
|
|
|
|
|
|
|
|
|
|
|
|
The operator can take inputs of one sample or in batch.
|
|
|
|
The operator can take batch inputs with size batch_size (batch_size >= 1).
|
|
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
)DOC");
|
|
|
|
}
|
|
|
|
}
|
|
|
|