|
|
|
@ -51,6 +51,12 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
bool DistPairDescend(std::tuple<int, int, T> pair1,
|
|
|
|
|
std::tuple<int, int, T> pair2) {
|
|
|
|
|
return std::get<2>(pair1) > std::get<2>(pair2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -58,46 +64,76 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
// The match_dist must be initialized to 0 at first.
|
|
|
|
|
void BipartiteMatch(const Tensor& dist, int* match_indices,
|
|
|
|
|
T* match_dist) const {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
|
|
|
|
|
int64_t row = dist.dims()[0];
|
|
|
|
|
int64_t col = dist.dims()[1];
|
|
|
|
|
auto* dist_data = dist.data<T>();
|
|
|
|
|
std::vector<int> row_pool;
|
|
|
|
|
for (int i = 0; i < row; ++i) {
|
|
|
|
|
row_pool.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
while (row_pool.size() > 0) {
|
|
|
|
|
int max_idx = -1;
|
|
|
|
|
int max_row_idx = -1;
|
|
|
|
|
T max_dist = -1;
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
if (match_indices[j] != -1) {
|
|
|
|
|
continue;
|
|
|
|
|
// Test result: When row==130 the speed of these two methods almost the same
|
|
|
|
|
if (row >= 130) {
|
|
|
|
|
std::vector<std::tuple<int, int, T>> match_pair;
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < row; ++i) {
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
match_pair.push_back(std::make_tuple(i, j, dist_data[i * col + j]));
|
|
|
|
|
}
|
|
|
|
|
for (size_t k = 0; k < row_pool.size(); ++k) {
|
|
|
|
|
int m = row_pool[k];
|
|
|
|
|
// distance is 0 between m-th row and j-th column
|
|
|
|
|
if (dist_data[m * col + j] < kEPS) {
|
|
|
|
|
}
|
|
|
|
|
std::sort(match_pair.begin(), match_pair.end(), DistPairDescend<T>);
|
|
|
|
|
std::vector<int> row_indices(row, -1);
|
|
|
|
|
|
|
|
|
|
int64_t idx = 0;
|
|
|
|
|
for (int64_t k = 0; k < row * col; ++k) {
|
|
|
|
|
int64_t i = std::get<0>(match_pair[k]);
|
|
|
|
|
int64_t j = std::get<1>(match_pair[k]);
|
|
|
|
|
T dist = std::get<2>(match_pair[k]);
|
|
|
|
|
|
|
|
|
|
if (idx >= row) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0) {
|
|
|
|
|
match_indices[j] = i;
|
|
|
|
|
row_indices[i] = j;
|
|
|
|
|
match_dist[j] = dist;
|
|
|
|
|
idx += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
std::vector<int> row_pool;
|
|
|
|
|
for (int i = 0; i < row; ++i) {
|
|
|
|
|
row_pool.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
while (row_pool.size() > 0) {
|
|
|
|
|
int max_idx = -1;
|
|
|
|
|
int max_row_idx = -1;
|
|
|
|
|
T max_dist = -1;
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
if (match_indices[j] != -1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (dist_data[m * col + j] > max_dist) {
|
|
|
|
|
max_idx = j;
|
|
|
|
|
max_row_idx = m;
|
|
|
|
|
max_dist = dist_data[m * col + j];
|
|
|
|
|
for (size_t k = 0; k < row_pool.size(); ++k) {
|
|
|
|
|
int m = row_pool[k];
|
|
|
|
|
// distance is 0 between m-th row and j-th column
|
|
|
|
|
if (dist_data[m * col + j] < kEPS) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (dist_data[m * col + j] > max_dist) {
|
|
|
|
|
max_idx = j;
|
|
|
|
|
max_row_idx = m;
|
|
|
|
|
max_dist = dist_data[m * col + j];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (max_idx == -1) {
|
|
|
|
|
// Cannot find good match.
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
|
|
|
|
|
match_indices[max_idx] = max_row_idx;
|
|
|
|
|
match_dist[max_idx] = max_dist;
|
|
|
|
|
// Erase the row index.
|
|
|
|
|
row_pool.erase(
|
|
|
|
|
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
|
|
|
|
|
if (max_idx == -1) {
|
|
|
|
|
// Cannot find good match.
|
|
|
|
|
break;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
|
|
|
|
|
match_indices[max_idx] = max_row_idx;
|
|
|
|
|
match_dist[max_idx] = max_dist;
|
|
|
|
|
// Erase the row index.
|
|
|
|
|
row_pool.erase(
|
|
|
|
|
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|