Merge pull request #6058 from kuke/refine_rank_loss_op

Revise comments in rank_loss_op
release/0.11.0
Yibing Liu 7 years ago committed by GitHub
commit 61d98f27ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel {
auto right_dims = ctx->GetInputDim("Right"); auto right_dims = ctx->GetInputDim("Right");
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size"); "All inputs must have the same size.");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), PADDLE_ENFORCE(
"All inputs must be row vector with size batch_size x 1."); (label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be 2-D tensors with shape [batch_size x 1].");
ctx->SetOutputDim("Out", label_dims); ctx->SetOutputDim("Out", label_dims);
} }
}; };
@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label", AddInput("Label",
"The label indicating A ranked higher than B or not, row vector."); "(2-D Tensor with shape [batch_size x 1]) "
AddInput("Left", "The output of RankNet for doc A, vector."); "The label indicating A ranked higher than B or not.");
AddInput("Right", "The output of RankNet for doc B, vetor."); AddInput("Left",
AddOutput("Out", "The output loss of RankLoss operator, vector."); "(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc A.");
AddInput("Right",
"(2-D Tensor with shape [batch_size x 1]) "
"The output of RankNet for doc B.");
AddOutput("Out",
"(2-D Tensor with shape [batch_size x 1]) "
"The output loss of RankLoss operator.");
AddComment(R"DOC( AddComment(R"DOC(
RankLoss Operator. RankLoss Operator.
@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
the input pair. the input pair.
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
(P_{i,j}), which represent the output of RankNet for the two docs and the label, (P_{i,j}), which represent the output score of RankNet for the two docs and
respectively, and yields the rank loss C_{i,j} using the following equation: the label respectively, and yields the rank loss C_{i,j} using the following
equation:
\f$$ $$
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\
o_{i,j} = o_i - o_j \\ o_{i,j} = o_i - o_j \\
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
\f$$ $$
The operator can take inputs of one sample or in batch. The operator can take batch inputs with size batch_size (batch_size >= 1).
)DOC"); )DOC");
} }

@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,

@ -4,7 +4,7 @@
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,

Loading…
Cancel
Save