|
|
|
@ -23,18 +23,17 @@ class BprLossOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LabelPos"),
|
|
|
|
|
"Input(LabelPos) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto label_pos_dims = ctx->GetInputDim("LabelPos");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_pos_dims.size(),
|
|
|
|
|
"Input(X) and Input(LabelPos) shall have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_pos_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(LabelPos) shall have the same shape "
|
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
|
"Input(X) and Input(Label) shall have the same shape "
|
|
|
|
|
"except the last dimension.");
|
|
|
|
|
|
|
|
|
|
auto y_dims = x_dims;
|
|
|
|
@ -60,25 +59,23 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("LabelPos"),
|
|
|
|
|
"Input(LabelPos) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
|
|
|
|
|
"Input(Y@GRAD) shoudl be not null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
|
|
|
|
"Output(X@GRAD) should be not null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto label_pos_dims = ctx->GetInputDim("LabelPos");
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
|
|
|
|
|
int rank = x_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
|
|
|
|
|
"Input(Y@Grad) and Input(X) should have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
label_pos_dims.size(), rank,
|
|
|
|
|
"Input(LabelPos) and Input(X) should have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), rank,
|
|
|
|
|
"Input(Label) and Input(X) should have the same rank.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(label_pos_dims, 0, rank - 1),
|
|
|
|
|
"The Input(X) and Input(LabelPos) should have the same "
|
|
|
|
|
framework::slice_ddim(label_dims, 0, rank - 1),
|
|
|
|
|
"The Input(X) and Input(Label) should have the same "
|
|
|
|
|
"shape except the last dimension.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
|
|
|
|
|
framework::slice_ddim(dy_dims, 0, rank - 1),
|
|
|
|
@ -86,8 +83,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
|
|
|
|
|
"shape except the last dimension.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
|
|
|
|
|
"The last dimension of Input(Y@Grad) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_pos_dims[rank - 1], 1,
|
|
|
|
|
" the last dimension of Input(LabelPos) should be 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
|
|
|
|
|
" the last dimension of Input(Label) should be 1.");
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
ctx->ShareLoD("X", framework::GradVarName("X"));
|
|
|
|
|
}
|
|
|
|
@ -111,7 +108,7 @@ class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"size is equal to the number of classes. This input is a "
|
|
|
|
|
"real number.");
|
|
|
|
|
AddInput(
|
|
|
|
|
"LabelPos",
|
|
|
|
|
"Label",
|
|
|
|
|
"(Tensor), the tensor which represents the ground truth. It has the "
|
|
|
|
|
"same shape with 'X' except the last dimension. the last dimension "
|
|
|
|
|
"size is 1.");
|
|
|
|
@ -122,7 +119,7 @@ class BprLossOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Bayesian Personalized Ranking Loss Operator.
|
|
|
|
|
|
|
|
|
|
This operator belongs to pairwise ranking loss. LabelPos is the desired item.
|
|
|
|
|
This operator belongs to pairwise ranking loss. Label is the desired item.
|
|
|
|
|
The loss at a given point in one session is defined as:
|
|
|
|
|
$Y[i] = -\frac{1}{N_{i}} * \sum_{j=0}^{N_{i}}\log(\sigma(X[i, Label[i]]-X[i, j]))$
|
|
|
|
|
|
|
|
|
|