|
|
|
@ -19,11 +19,7 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
class MarginRankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
MarginRankLossOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
@ -35,13 +31,11 @@ class MarginRankLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
|
|
|
|
|
auto x1_dims = ctx.Input<framework::Tensor>("X1")->dims();
|
|
|
|
|
auto x2_dims = ctx.Input<framework::Tensor>("X2")->dims();
|
|
|
|
|
PADDLE_ENFORCE((label_dims.size() == 1) && (x1_dims.size() == 1) &&
|
|
|
|
|
(x2_dims.size() == 1),
|
|
|
|
|
"The rank of all inputs must be 1.");
|
|
|
|
|
PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims),
|
|
|
|
|
"All inputs must have the same size");
|
|
|
|
|
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
|
|
|
|
|
PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) &&
|
|
|
|
|
(label_dims.size() == 2) && (label_dims[1] == 1),
|
|
|
|
|
"All inputs must be vector with the same size");
|
|
|
|
|
ctx.Output<framework::LoDTensor>("Activated")->Resize(label_dims);
|
|
|
|
|
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -51,18 +45,27 @@ class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
MarginRankLossOpMaker(framework::OpProto *proto,
|
|
|
|
|
framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("Label", "The label indicating X1 ranked higher than X2 or not.");
|
|
|
|
|
AddInput("X1", "The first input of MarginRankLossOp.");
|
|
|
|
|
AddInput("X2", "The second input of MarginRankLossOp");
|
|
|
|
|
AddAttr<AttrType>("margin", "Margin for MarginRankLossOp").SetDefault(0);
|
|
|
|
|
AddOutput("Out", "The output loss of MarginRankLoss operator");
|
|
|
|
|
AddInput("X1", "The first input of MarginRankLossOp, row vector.");
|
|
|
|
|
AddInput("X2", "The second input of MarginRankLossOp, row vector.");
|
|
|
|
|
AddInput("Label",
|
|
|
|
|
"The label indicating X1 ranked higher than X2 "
|
|
|
|
|
"or not, row vector.");
|
|
|
|
|
AddAttr<AttrType>("margin", "Margin for MarginRankLossOp, scalar.")
|
|
|
|
|
.SetDefault(0);
|
|
|
|
|
AddOutput("Activated",
|
|
|
|
|
"Intermediate tensor to indicate "
|
|
|
|
|
"whether Output(Out) is activated")
|
|
|
|
|
"Intermediate tensor to indicate whether each element of "
|
|
|
|
|
"Output(Out) is activated")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddComment(R"DOC(MarginRankLoss operator
|
|
|
|
|
AddOutput("Out", "The output loss of MarginRankLoss operator");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
|
|
MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`}
|
|
|
|
|
and `Label` with attribuute `margin`, where `Label == 1` indicating X1 is
|
|
|
|
|
ranked higher than `X2`, otherwise `Label == -1`. The loss turns out
|
|
|
|
|
|
|
|
|
|
loss(X1, X2, Label) = max(0, -Label * (X1-X2) + margin)
|
|
|
|
|
|
|
|
|
|
loss(x1, x2, y) = max(0, -label * (x1-x2) + margin)
|
|
|
|
|
For batch input, `X1`, `X2` and `Label` all have the same size batch_size x 1.
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -70,11 +73,7 @@ loss(x1, x2, y) = max(0, -label * (x1-x2) + margin)
|
|
|
|
|
|
|
|
|
|
class MarginRankLossGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
MarginRankLossGradOp(const std::string &type,
|
|
|
|
|
const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|