|
|
|
@ -25,47 +25,67 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
// Validates the operator's inputs and sets the shapes of its outputs.
//
// Enforces that Input(Label), Input(X1), Input(X2) and Output(Out) are
// present, and that the three inputs share one 2-D shape whose second
// dimension is 1. Output(Activated) and Output(Out) are then resized to
// that shape.
void InferShape(const framework::InferShapeContext &ctx) const override {
  // input check
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
                          "Input(Label) shouldn't be null.");
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null.");
  PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null.");
  // The message must name the variable actually being checked ("Out").
  PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                          "Output(Out) shouldn't be null.");

  auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
  auto x1_dims = ctx.Input<framework::Tensor>("X1")->dims();
  auto x2_dims = ctx.Input<framework::Tensor>("X2")->dims();
  // All three inputs must have the same shape, be 2-D, and have a second
  // dimension of exactly 1.
  PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) &&
                     (label_dims.size() == 2) && (label_dims[1] == 1),
                 "All inputs must be vector with the same size.");

  // Each output is resized only when its tensor pointer is non-null.
  auto act_t = ctx.Output<framework::LoDTensor>("Activated");
  auto out_t = ctx.Output<framework::LoDTensor>("Out");
  if (act_t) {
    act_t->Resize(label_dims);
  }
  if (out_t) {
    out_t->Resize(label_dims);
  }
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
// Registers the inputs, attribute, outputs and user-facing documentation of
// the MarginRankLoss operator on the given proto / attribute checker.
//
// Inputs:    X1, X2, Label.
// Attribute: margin (type T, default 0).
// Outputs:   Activated (marked intermediate) and Out.
MarginRankLossOpMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
    : OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput("X1",
           "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
           "X1 is the score for one item to be ranked.");
  AddInput("X2",
           "(2-D tensor with shape [batch_size x 1]) In pairwise ranking, "
           "X2 is the score for another item to be ranked.");
  AddInput("Label",
           "(2-D tensor with shape [batch_size x 1]) "
           "The label indicating X1 ranked higher than X2 or not, "
           "can only be +1 or -1.");
  AddAttr<T>("margin", "(scalar, default 0) Margin for MarginRankLossOp.")
      .SetDefault(static_cast<T>(0));
  AddOutput("Activated",
            "(2-D tensor with shape [batch_size x 1]) Intermediate tensor "
            "to indicate whether each element of Output(Out) is activated.")
      .AsIntermediate();
  // Adjacent string literals are concatenated verbatim, so the separating
  // space has to live inside one of the literals.
  AddOutput("Out",
            "(2-D tensor with shape [batch_size x 1]) "
            "The output loss of MarginRankLoss operator");
  AddComment(R"DOC(

MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`}
and the `Label` with attribute `margin`, where `Label = +1` indicating X1 is
ranked higher than `X2`, otherwise `Label = -1`. The loss turns out

loss(X1, X2, Label) = max(0, -Label * (X1 - X2) + margin)

The attribute `margin` involved here helps make the predictions more robust.
Only when `Label * (X1 - X2)` is less than `margin` do these two items
contribute to the final loss.

For batch input with size `batch_size`, `X1`, `X2` and `Label`
all have the same shape [batch_size x 1].

)DOC");
}
|
|
|
|
|