|
|
|
@ -51,6 +51,12 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Comparator for (int, int, T) tuples that orders them by their third
// element, largest first. Returns true when pair1's third element is
// strictly greater than pair2's; the first two elements are ignored.
template <class T>
bool DistPairDescend(std::tuple<int, int, T> pair1,
                     std::tuple<int, int, T> pair2) {
  const T& lhs_value = std::get<2>(pair1);
  const T& rhs_value = std::get<2>(pair2);
  return lhs_value > rhs_value;
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -58,11 +64,40 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
// The match_dist must be initialized to 0 at first.
|
|
|
|
|
void BipartiteMatch(const Tensor& dist, int* match_indices,
|
|
|
|
|
T* match_dist) const {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
|
|
|
|
|
int64_t row = dist.dims()[0];
|
|
|
|
|
int64_t col = dist.dims()[1];
|
|
|
|
|
auto* dist_data = dist.data<T>();
|
|
|
|
|
// Test result: when row == 130, the speeds of these two methods are almost the same.
|
|
|
|
|
if (row >= 130) {
|
|
|
|
|
std::vector<std::tuple<int, int, T>> match_pair;
|
|
|
|
|
|
|
|
|
|
for (int64_t i = 0; i < row; ++i) {
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
match_pair.push_back(std::make_tuple(i, j, dist_data[i * col + j]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::sort(match_pair.begin(), match_pair.end(), DistPairDescend<T>);
|
|
|
|
|
std::vector<int> row_indices(row, -1);
|
|
|
|
|
|
|
|
|
|
int64_t idx = 0;
|
|
|
|
|
for (int64_t k = 0; k < row * col; ++k) {
|
|
|
|
|
int64_t i = std::get<0>(match_pair[k]);
|
|
|
|
|
int64_t j = std::get<1>(match_pair[k]);
|
|
|
|
|
T dist = std::get<2>(match_pair[k]);
|
|
|
|
|
|
|
|
|
|
if (idx >= row) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0) {
|
|
|
|
|
match_indices[j] = i;
|
|
|
|
|
row_indices[i] = j;
|
|
|
|
|
match_dist[j] = dist;
|
|
|
|
|
idx += 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
std::vector<int> row_pool;
|
|
|
|
|
for (int i = 0; i < row; ++i) {
|
|
|
|
|
row_pool.push_back(i);
|
|
|
|
@ -101,6 +136,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
|
|
|
|
|
T overlap_threshold) const {
|
|
|
|
|