|
|
|
@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
|
|
|
|
|
T overlap_threshold) const {
|
|
|
|
|
constexpr T kEPS = static_cast<T>(1e-6);
|
|
|
|
|
int64_t row = dist.dims()[0];
|
|
|
|
|
int64_t col = dist.dims()[1];
|
|
|
|
|
auto* dist_data = dist.data<T>();
|
|
|
|
|
for (int64_t j = 0; j < col; ++j) {
|
|
|
|
|
if (match_indices[j] != -1) {
|
|
|
|
|
// the j-th column has been matched to one entity.
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
int max_row_idx = -1;
|
|
|
|
|
T max_dist = -1;
|
|
|
|
|
for (int i = 0; i < row; ++i) {
|
|
|
|
|
T dist = dist_data[i * col + j];
|
|
|
|
|
if (dist < kEPS) {
|
|
|
|
|
// distance is 0 between m-th row and j-th column
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (dist >= overlap_threshold && dist > max_dist) {
|
|
|
|
|
max_row_idx = i;
|
|
|
|
|
max_dist = dist;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (max_row_idx != -1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(match_indices[j], -1);
|
|
|
|
|
match_indices[j] = max_row_idx;
|
|
|
|
|
match_dist[j] = max_dist;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* dist_mat = context.Input<LoDTensor>("DistMat");
|
|
|
|
|
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
|
|
|
|
@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
int* indices = match_indices->data<int>();
|
|
|
|
|
T* dist = match_dist->data<T>();
|
|
|
|
|
auto type = context.Attr<std::string>("match_type");
|
|
|
|
|
auto threshold = context.Attr<float>("dist_threshold");
|
|
|
|
|
if (n == 1) {
|
|
|
|
|
BipartiteMatch(*dist_mat, indices, dist);
|
|
|
|
|
if (type == "per_prediction") {
|
|
|
|
|
ArgMaxMatch(*dist_mat, indices, dist, threshold);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto lod = dist_mat->lod().back();
|
|
|
|
|
for (size_t i = 0; i < lod.size() - 1; ++i) {
|
|
|
|
|
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
|
|
|
|
|
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
|
|
|
|
|
if (type == "per_prediction") {
|
|
|
|
|
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"This tensor can contain LoD information to represent a batch of "
|
|
|
|
|
"inputs. One instance of this batch can contain different numbers of "
|
|
|
|
|
"entities.");
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"match_type",
|
|
|
|
|
"(string, defalut: per_prediction) "
|
|
|
|
|
"The type of matching method, should be 'bipartite' or "
|
|
|
|
|
"'per_prediction', 'bipartite' by defalut.")
|
|
|
|
|
.SetDefault("bipartite")
|
|
|
|
|
.InEnum({"bipartite", "per_prediction"});
|
|
|
|
|
AddAttr<float>(
|
|
|
|
|
"dist_threshold",
|
|
|
|
|
"(float, defalut: 0.5) "
|
|
|
|
|
"If `match_type` is 'per_prediction', this threshold is to determine "
|
|
|
|
|
"the extra matching bboxes based on the maximum distance.")
|
|
|
|
|
.SetDefault(0.5);
|
|
|
|
|
AddOutput("ColToRowMatchIndices",
|
|
|
|
|
"(Tensor) A 2-D Tensor with shape [N, M] in int type. "
|
|
|
|
|
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
|
|
|
|
@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
|
|
|
|
|
find the matched column for each row, also can find the matched row for
|
|
|
|
|
each column. And this operator only calculate matched indices from column
|
|
|
|
|
to row. For each instance, the number of matched indices is the number of
|
|
|
|
|
of columns of the input ditance matrix.
|
|
|
|
|
of columns of the input distance matrix.
|
|
|
|
|
|
|
|
|
|
There are two outputs to save matched indices and distance.
|
|
|
|
|
A simple description, this algothrim matched the best (maximum distance)
|
|
|
|
|
A simple description, this algorithm matched the best (maximum distance)
|
|
|
|
|
row entity to the column entity and the matched indices are not duplicated
|
|
|
|
|
in each row of ColToRowMatchIndices. If the column entity is not matched
|
|
|
|
|
any row entity, set -1 in ColToRowMatchIndices.
|
|
|
|
|