|
|
|
@ -21,6 +21,8 @@ namespace operators {
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using LoDTensor = framework::LoDTensor;
|
|
|
|
|
|
|
|
|
|
// Distances strictly below this threshold are treated as zero, i.e. the
// (row, column) pair is skipped during matching. The constant must be a
// floating-point type: declared as `char`, 1e-6 is truncated to 0 and the
// threshold silently disappears.
constexpr double kEPS = 1e-6;
|
|
|
|
|
|
|
|
|
|
class BipartiteMatchOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -41,12 +43,13 @@ template <typename T>
|
|
|
|
|
class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
// The match_indices must be initialized to -1 at first.
|
|
|
|
|
// The match_dis must be initialized to 0 at first.
|
|
|
|
|
void BipartiteMatch(const Tensor& dis, int* match_indices,
|
|
|
|
|
T* match_dis) const {
|
|
|
|
|
int64_t row = dis.dims()[0];
|
|
|
|
|
int64_t col = dis.dims()[1];
|
|
|
|
|
auto* dis_data = dis.data<T>();
|
|
|
|
|
// The match_dist must be initialized to 0 at first.
|
|
|
|
|
void BipartiteMatch(const Tensor& dist, int* match_indices,
|
|
|
|
|
T* match_dist) const {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
|
|
|
|
|
int64_t row = dist.dims()[0];
|
|
|
|
|
int64_t col = dist.dims()[1];
|
|
|
|
|
auto* dist_data = dist.data<T>();
|
|
|
|
|
std::vector<int> row_pool;
|
|
|
|
|
for (int i = 0; i < row; ++i) {
|
|
|
|
|
row_pool.push_back(i);
|
|
|
|
@ -54,7 +57,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
while (row_pool.size() > 0) {
|
|
|
|
|
int max_idx = -1;
|
|
|
|
|
int max_row_idx = -1;
|
|
|
|
|
T max_dis = -1;
|
|
|
|
|
T max_dist = -1;
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
if (match_indices[j] != -1) {
|
|
|
|
|
continue;
|
|
|
|
@ -62,13 +65,13 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (int k = 0; k < row_pool.size(); ++k) {
|
|
|
|
|
int m = row_pool[k];
|
|
|
|
|
// distance is 0 between m-th row and j-th column
|
|
|
|
|
if (dis_data[m * col + j] < 1e-6) {
|
|
|
|
|
if (dist_data[m * col + j] < kEPS) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (dis_data[m * col + j] > max_dis) {
|
|
|
|
|
if (dist_data[m * col + j] > max_dist) {
|
|
|
|
|
max_idx = j;
|
|
|
|
|
max_row_idx = m;
|
|
|
|
|
max_dis = dis_data[m * col + j];
|
|
|
|
|
max_dist = dist_data[m * col + j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
|
|
|
|
|
match_indices[max_idx] = max_row_idx;
|
|
|
|
|
match_dis[max_idx] = max_dis;
|
|
|
|
|
match_dist[max_idx] = max_dist;
|
|
|
|
|
// Erase the row index.
|
|
|
|
|
row_pool.erase(
|
|
|
|
|
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
|
|
|
|
@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* dis_mat = context.Input<LoDTensor>("DisMat");
|
|
|
|
|
auto* dist_mat = context.Input<LoDTensor>("DisMat");
|
|
|
|
|
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
|
|
|
|
|
auto* match_dis = context.Output<Tensor>("ColToRowMatchDis");
|
|
|
|
|
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
|
|
|
|
|
|
|
|
|
|
auto col = dis_mat->dims()[1];
|
|
|
|
|
auto col = dist_mat->dims()[1];
|
|
|
|
|
|
|
|
|
|
int64_t n = dis_mat->lod().size() == 0
|
|
|
|
|
int64_t n = dist_mat->lod().size() == 0UL
|
|
|
|
|
? 1
|
|
|
|
|
: static_cast<int64_t>(dis_mat->lod().back().size() - 1);
|
|
|
|
|
: static_cast<int64_t>(dist_mat->lod().back().size() - 1);
|
|
|
|
|
if (dist_mat->lod().size()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL,
|
|
|
|
|
"Only support 1 level of LoD.");
|
|
|
|
|
}
|
|
|
|
|
match_indices->mutable_data<int>({n, col}, context.GetPlace());
|
|
|
|
|
match_dis->mutable_data<T>({n, col}, context.GetPlace());
|
|
|
|
|
match_dist->mutable_data<T>({n, col}, context.GetPlace());
|
|
|
|
|
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, int> iset;
|
|
|
|
|
iset(dev_ctx, match_indices, static_cast<int>(-1));
|
|
|
|
|
math::SetConstant<platform::CPUDeviceContext, T> tset;
|
|
|
|
|
tset(dev_ctx, match_dis, static_cast<T>(0));
|
|
|
|
|
tset(dev_ctx, match_dist, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
int* indices = match_indices->data<int>();
|
|
|
|
|
T* dis = match_dis->data<T>();
|
|
|
|
|
T* dist = match_dist->data<T>();
|
|
|
|
|
if (n == 1) {
|
|
|
|
|
BipartiteMatch(*dis_mat, indices, dis);
|
|
|
|
|
BipartiteMatch(*dist_mat, indices, dist);
|
|
|
|
|
} else {
|
|
|
|
|
auto lod = dis_mat->lod().back();
|
|
|
|
|
auto lod = dist_mat->lod().back();
|
|
|
|
|
for (size_t i = 0; i < lod.size() - 1; ++i) {
|
|
|
|
|
Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]);
|
|
|
|
|
BipartiteMatch(one_ins, indices + i * col, dis + i * col);
|
|
|
|
|
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
|
|
|
|
|
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"represented by each row and each column. For example, assumed one "
|
|
|
|
|
"entity is A with shape [K], another entity is B with shape [M]. The "
|
|
|
|
|
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
|
|
|
|
|
"the distance is, the more similar the pairs are. Please note, "
|
|
|
|
|
"the distance is, the better matching the pairs are. Please note, "
|
|
|
|
|
"This tensor can contain LoD information to represent a batch of "
|
|
|
|
|
"inputs. One instance of this batch can contain different numbers of "
|
|
|
|
|
"entities.");
|
|
|
|
@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
|
|
|
|
|
"means B[j] does not match any entity in i-th instance. "
|
|
|
|
|
"Otherwise, it means B[j] is matched to row "
|
|
|
|
|
"RowToColMatchIndices[i][j] in i-th instance. The row number of "
|
|
|
|
|
"i-th instance is saved in RowToColMatchIndices[i][j].");
|
|
|
|
|
"ColToRowMatchIndices[i][j] in i-th instance. The row number of "
|
|
|
|
|
"i-th instance is saved in ColToRowMatchIndices[i][j].");
|
|
|
|
|
AddOutput("ColToRowMatchDis",
|
|
|
|
|
"(Tensor) A 2-D Tensor with shape [N, M] in float type. "
|
|
|
|
|
"N is batch size. If ColToRowMatchIndices[i][j] is -1, "
|
|
|
|
|
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
|
|
|
|
|
"RowToColMatchIndices[i][j] = d, and the row offsets of each "
|
|
|
|
|
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
|
|
|
|
|
"instance are called LoD. Then "
|
|
|
|
|
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
This operator is a greedy bipartite matching algorithm, which is used to
|
|
|
|
|
obtain the matching with the (greedy) maximum distance based on the input
|
|
|
|
|
distance matrix. There are two outputs to save matched indices and distance.
|
|
|
|
|
And this operator only calculate matched indices from column to row.
|
|
|
|
|
obtain the matching with the maximum distance based on the input
|
|
|
|
|
distance matrix. For input 2D matrix, the bipartite matching algorithm can
|
|
|
|
|
find the matched column for each row, also can find the matched row for
|
|
|
|
|
each column. And this operator only calculates matched indices from column
|
|
|
|
|
to row. For each instance, the number of matched indices is the number of
|
|
|
|
|
columns of the input distance matrix.
|
|
|
|
|
|
|
|
|
|
There are two outputs to save matched indices and distance.
|
|
|
|
|
A simple description, this algorithm matches the best (maximum distance)
|
|
|
|
|
row entity to the column entity and the matched indices are not duplicated
|
|
|
|
|
in each row of ColToRowMatchIndices. If the column entity is not matched
|
|
|
|
|