commit
c8f1389bea
@ -0,0 +1,177 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/positive_negative_pair_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class PositiveNegativePairOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Score"),
|
||||
"Input(Score) of PositiveNegativePairOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Label"),
|
||||
"Input(Label) of PositiveNegativePairOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("QueryID"),
|
||||
"Input(QueryID) of PositiveNegativePairOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("PositivePair"),
|
||||
"Output(PositivePair) of PositiveNegativePairOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("NegativePair"),
|
||||
"Output(NegativePair) of PositiveNegativePairOp should not be null.");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasOutput("NeutralPair"),
|
||||
"Output(NeutralPair) of PositiveNegativePairOp should not be null.");
|
||||
auto scalar_dim = framework::make_ddim({1});
|
||||
if (ctx->HasInput("AccumulatePositivePair") ||
|
||||
ctx->HasInput("AccumulateNegativePair") ||
|
||||
ctx->HasInput("AccumulateNeutralPair")) {
|
||||
PADDLE_ENFORCE(ctx->HasInput("AccumulatePositivePair") &&
|
||||
ctx->HasInput("AccumulateNegativePair") &&
|
||||
ctx->HasInput("AccumulateNeutralPair"),
|
||||
"All optional inputs(AccumulatePositivePair, "
|
||||
"AccumulateNegativePair, AccumulateNeutralPair) of "
|
||||
"PositiveNegativePairOp are required if one of them is "
|
||||
"specified.");
|
||||
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulatePositivePair"), scalar_dim,
|
||||
"Shape of AccumulatePositivePair should be {1}.");
|
||||
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNegativePair"), scalar_dim,
|
||||
"Shape of AccumulateNegativePair should be {1}.");
|
||||
PADDLE_ENFORCE_EQ(ctx->GetInputDim("AccumulateNeutralPair"), scalar_dim,
|
||||
"Shape of AccumulateNeutralPair should be {1}.");
|
||||
}
|
||||
|
||||
auto score_dim = ctx->GetInputDim("Score");
|
||||
auto label_dim = ctx->GetInputDim("Label");
|
||||
auto query_dim = ctx->GetInputDim("QueryID");
|
||||
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
|
||||
PADDLE_ENFORCE_EQ(
|
||||
label_dim[0], score_dim[0],
|
||||
"Tensor Score and Label should have the same height (batch size).");
|
||||
PADDLE_ENFORCE_EQ(label_dim[1], 1,
|
||||
"The width of Label should be 1, i.e. each item should "
|
||||
"have a scalar label.");
|
||||
PADDLE_ENFORCE(query_dim == label_dim,
|
||||
"QueryID should have the same shape as Label.");
|
||||
if (ctx->HasInput("Weight")) {
|
||||
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
|
||||
"Weight should have the same shape as Label.");
|
||||
}
|
||||
int column = ctx->Attrs().Get<int>("column");
|
||||
auto depth = score_dim[1];
|
||||
PADDLE_ENFORCE(column < depth && column >= -depth,
|
||||
"Attribute column should be in the range of [-%l, %l)",
|
||||
depth, depth);
|
||||
|
||||
ctx->SetOutputDim("PositivePair", scalar_dim);
|
||||
ctx->SetOutputDim("NegativePair", scalar_dim);
|
||||
ctx->SetOutputDim("NeutralPair", scalar_dim);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::DataType IndicateDataType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::ToDataType(ctx.Input<Tensor>("Score")->type());
|
||||
}
|
||||
};
|
||||
|
||||
class PositiveNegativePairOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
PositiveNegativePairOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("Score",
|
||||
"(Tensor, float) Model Score on an item (with "
|
||||
"respect to QueryID). It's a 2-D tensor with shape [batch_size, "
|
||||
"depth], where the column specified by the attribute \"column\" "
|
||||
"is used as item score.");
|
||||
AddInput("Label",
|
||||
"(Tensor, float) Label of an item (with repsect to "
|
||||
"QueryId). It's a 2-D tensor with shape [batch_size, 1].");
|
||||
AddInput("QueryID",
|
||||
"(Tensor, int64) Query ID that indicates the context. Its shape "
|
||||
"should be the same as Label.");
|
||||
AddInput(
|
||||
"AccumulatePositivePair",
|
||||
"(float) Optional. The accumulated number of positive pairs over a "
|
||||
"stream of data. If provided, the output PositivePair will be "
|
||||
"initialized with this number rather than 0. it won't be modified "
|
||||
"in place.")
|
||||
.AsDispensable();
|
||||
AddInput(
|
||||
"AccumulateNegativePair",
|
||||
"(float) Optional. The accumulated number of negative pairs over a "
|
||||
"stream of data. If provided, the output NegativePair will be "
|
||||
"initialized with this number rather than 0. it won't be modified "
|
||||
"in place.")
|
||||
.AsDispensable();
|
||||
AddInput("AccumulateNeutralPair",
|
||||
"(float) Optional. The accumulated number of neutral pairs over a "
|
||||
"stream of data. If provided, the output NeutralPair will be "
|
||||
"initialized with this number rather than 0. it won't be modified "
|
||||
"in place.")
|
||||
.AsDispensable();
|
||||
AddInput("Weight",
|
||||
"(float) Optional. Weight of current item. If specified, its "
|
||||
"shape should be the same as Label, and the meaning of the output "
|
||||
"changes from numbers of pairs to the total sum of pairs' "
|
||||
"weights. Weight of a pair of items is the average of their "
|
||||
"weights.")
|
||||
.AsDispensable();
|
||||
AddOutput("PositivePair",
|
||||
"(float) Number of positive pairs, i.e. the pairs of "
|
||||
"items that are ranked correctly.");
|
||||
AddOutput("NegativePair",
|
||||
"(float) Number of negative pairs, i.e. the pairs of "
|
||||
"items that are ranked incorrectly.");
|
||||
AddOutput("NeutralPair",
|
||||
"(float) Number of neutral pairs, i.e. the pairs of items "
|
||||
"that have the same score.")
|
||||
.AsDispensable();
|
||||
AddAttr<int>(
|
||||
"column",
|
||||
"(int, default -1) The column position of Score used to rank items in "
|
||||
"descending order. It must be in the range of [-rank(Score), "
|
||||
"rank(Score)). "
|
||||
"If `dim < 0`, the dim to reduce is `rank + dim`. "
|
||||
"Noting that reducing on the first dim will make the LoD info lost.")
|
||||
.SetDefault(0);
|
||||
AddComment(R"DOC(
|
||||
PositiveNegativePairOp can be used to evaluate Learning To Rank(LTR)
|
||||
model performance.
|
||||
Within some context, e.g. the "query", a LTR model generates scores
|
||||
for a list of items, which gives a partial order of the items.
|
||||
PositiveNegativePairOp takes a list of reference rank order
|
||||
(Input("Label")) and the model generated scores (Input(Score)) as
|
||||
inputs and counts the pairs that ranked correctly and incorrectly.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(positive_negative_pair,
|
||||
ops::PositiveNegativePairOp,
|
||||
ops::PositiveNegativePairOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
positive_negative_pair,
|
||||
ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::PositiveNegativePairKernel<paddle::platform::CPUPlace, double>);
|
@ -0,0 +1,114 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/utils/Logging.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
template <typename Place, typename T>
|
||||
class PositiveNegativePairKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
struct PredictionResult {
|
||||
PredictionResult(T score, T label, T weight)
|
||||
: score(score), label(label), weight(weight) {}
|
||||
T score;
|
||||
T label;
|
||||
T weight;
|
||||
};
|
||||
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto score_t = context.Input<Tensor>("Score");
|
||||
auto label_t = context.Input<Tensor>("Label");
|
||||
auto query_t = context.Input<Tensor>("QueryID");
|
||||
auto acc_positive_t = context.Input<Tensor>("AccumulatePositivePair");
|
||||
auto acc_negative_t = context.Input<Tensor>("AccumulateNegativePair");
|
||||
auto acc_neutral_t = context.Input<Tensor>("AccumulateNeutralPair");
|
||||
auto positive_t = context.Output<Tensor>("PositivePair");
|
||||
auto negative_t = context.Output<Tensor>("NegativePair");
|
||||
auto neutral_t = context.Output<Tensor>("NeutralPair");
|
||||
auto weight_t = context.Input<Tensor>("Weight");
|
||||
|
||||
auto score = score_t->data<T>();
|
||||
auto label = label_t->data<T>();
|
||||
auto query = query_t->data<int64_t>();
|
||||
const T* weight = nullptr;
|
||||
if (weight_t != nullptr) {
|
||||
weight = weight_t->data<T>();
|
||||
}
|
||||
T* positive = positive_t->mutable_data<T>(context.GetPlace());
|
||||
T* negative = negative_t->mutable_data<T>(context.GetPlace());
|
||||
T* neutral = neutral_t->mutable_data<T>(context.GetPlace());
|
||||
|
||||
auto score_dim = score_t->dims();
|
||||
auto batch_size = score_dim[0];
|
||||
auto width = score_dim[1];
|
||||
auto column = context.Attr<int32_t>("column");
|
||||
if (column < 0) {
|
||||
column += width;
|
||||
}
|
||||
|
||||
// construct document instances for each query: Query => List[<score#0,
|
||||
// label#0, weight#0>, ...]
|
||||
std::unordered_map<int64_t, std::vector<PredictionResult>> predictions;
|
||||
for (auto i = 0; i < batch_size; ++i) {
|
||||
if (predictions.find(query[i]) == predictions.end()) {
|
||||
predictions.emplace(
|
||||
std::make_pair(query[i], std::vector<PredictionResult>()));
|
||||
}
|
||||
predictions[query[i]].emplace_back(score[i * width + column], label[i],
|
||||
weight_t != nullptr ? weight[i] : 1.0);
|
||||
}
|
||||
|
||||
// for each query, accumulate pair counts
|
||||
T pos = 0, neg = 0, neu = 0;
|
||||
if (acc_positive_t != nullptr && acc_negative_t != nullptr &&
|
||||
acc_neutral_t != nullptr) {
|
||||
pos = acc_positive_t->data<T>()[0];
|
||||
neg = acc_negative_t->data<T>()[0];
|
||||
neu = acc_neutral_t->data<T>()[0];
|
||||
}
|
||||
auto evaluate_one_list = [&pos, &neg,
|
||||
&neu](std::vector<PredictionResult> vec) {
|
||||
for (auto ite1 = vec.begin(); ite1 != vec.end(); ++ite1) {
|
||||
for (auto ite2 = ite1 + 1; ite2 != vec.end(); ++ite2) {
|
||||
if (ite1->label == ite2->label) { // labels are equal, ignore.
|
||||
continue;
|
||||
}
|
||||
T w = (ite1->weight + ite2->weight) * 0.5;
|
||||
if (ite1->score == ite2->score) {
|
||||
neu += w;
|
||||
}
|
||||
(ite1->score - ite2->score) * (ite1->label - ite2->label) > 0.0
|
||||
? pos += w
|
||||
: neg += w;
|
||||
}
|
||||
}
|
||||
};
|
||||
for (auto prediction : predictions) {
|
||||
evaluate_one_list(prediction.second);
|
||||
}
|
||||
*positive = pos;
|
||||
*negative = neg;
|
||||
*neutral = neu;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,106 @@
|
||||
import unittest
|
||||
import itertools
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def py_pnpair_op(score, label, query, column=-1, weight=None):
|
||||
# group by query id
|
||||
predictions = {}
|
||||
batch_size = label.shape[0]
|
||||
if weight is None:
|
||||
weight = np.ones(shape=(batch_size, 1)).astype('float32')
|
||||
for s, l, q, w in zip(score, label, query, weight):
|
||||
s, l, q, w = s[column], l[0], q[0], w[0]
|
||||
if q not in predictions:
|
||||
predictions[q] = []
|
||||
predictions[q].append((s, l, w))
|
||||
|
||||
# accumulate statistics
|
||||
pos, neg, neu = 0, 0, 0
|
||||
for _, ranks in predictions.items():
|
||||
for e1, e2 in itertools.combinations(ranks, 2):
|
||||
s1, s2, l1, l2, w1, w2 = e1[0], e2[0], e1[1], e2[1], e1[2], e2[2]
|
||||
w = (w1 + w2) * 0.5
|
||||
if l1 == l2:
|
||||
continue
|
||||
if s1 == s2:
|
||||
neu += w
|
||||
elif (s1 - s2) * (l1 - l2) > 0:
|
||||
pos += w
|
||||
else:
|
||||
neg += w
|
||||
|
||||
return np.array(pos).astype('float32'), np.array(neg).astype(
|
||||
'float32'), np.array(neu).astype('float32')
|
||||
|
||||
|
||||
class TestPositiveNegativePairOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = 'positive_negative_pair'
|
||||
batch_size = 20
|
||||
max_query_id = 5
|
||||
score = np.random.normal(size=(batch_size, 1)).astype('float32')
|
||||
label = np.random.normal(size=(batch_size, 1)).astype('float32')
|
||||
query = np.array(
|
||||
[np.random.randint(max_query_id) for i in range(batch_size)])
|
||||
query = np.reshape(query, newshape=(batch_size, 1)).astype('int64')
|
||||
|
||||
pos, neg, neu = py_pnpair_op(score, label, query)
|
||||
self.inputs = {'Score': score, 'Label': label, 'QueryID': query}
|
||||
self.attrs = {'column': -1}
|
||||
self.outputs = {
|
||||
'PositivePair': pos,
|
||||
'NegativePair': neg,
|
||||
'NeutralPair': neu
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestPositiveNegativePairOpAccumulateWeight(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = 'positive_negative_pair'
|
||||
batch_size = 20
|
||||
max_query_id = 5
|
||||
max_random_num = 2 << 15
|
||||
score_dim = 2
|
||||
score = np.random.normal(size=(batch_size, 2)).astype('float32')
|
||||
label = np.random.normal(size=(batch_size, 1)).astype('float32')
|
||||
weight = np.random.normal(size=(batch_size, 1)).astype('float32')
|
||||
query = np.array(
|
||||
[np.random.randint(max_query_id) for i in range(batch_size)])
|
||||
query = np.reshape(query, newshape=(batch_size, 1)).astype('int64')
|
||||
acc_pos = np.reshape(
|
||||
np.random.randint(max_random_num), newshape=(1)).astype('float32')
|
||||
acc_neg = np.reshape(
|
||||
np.random.randint(max_random_num), newshape=(1)).astype('float32')
|
||||
acc_neu = np.reshape(
|
||||
np.random.randint(max_random_num), newshape=(1)).astype('float32')
|
||||
column = np.random.randint(score_dim)
|
||||
|
||||
pos, neg, neu = py_pnpair_op(
|
||||
score, label, query, column=column, weight=weight)
|
||||
self.inputs = {
|
||||
'Score': score,
|
||||
'Label': label,
|
||||
'QueryID': query,
|
||||
'AccumulatePositivePair': acc_pos,
|
||||
'AccumulateNegativePair': acc_neg,
|
||||
'AccumulateNeutralPair': acc_neu,
|
||||
'Weight': weight
|
||||
}
|
||||
self.attrs = {'column': column}
|
||||
self.outputs = {
|
||||
'PositivePair': pos + acc_pos,
|
||||
'NegativePair': neg + acc_neg,
|
||||
'NeutralPair': neu + acc_neu
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue