Support empty bbox in bipartite math op (#26488)

test_feature_precision_test_c
qingqing01 5 years ago committed by GitHub
parent 87843bebde
commit 24566e951c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -222,10 +222,12 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
} else {
auto lod = dist_mat->lod().back();
for (size_t i = 0; i < lod.size() - 1; ++i) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
if (lod[i + 1] > lod[i]) {
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
if (type == "per_prediction") {
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
}
}
}
}

@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
"""Bipartite Matching algorithm for batch input.
Arg:
distance (numpy.array) : The distance of two entries with shape [M, N].
lod (list of int): The offsets of each input in this batch.
lod (list of int): The length of each input in this batch.
"""
n = len(lod)
m = distance.shape[1]
@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
match_dist = np.zeros((n, m), dtype=np.float32)
cur_offset = 0
for i in range(n):
if lod[i] == 0: continue
bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :],
match_indices[i, :], match_dist[i, :])
if match_type == 'per_prediction':
@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest):
self.check_output()
class TestBipartiteMatchOpWithEmptyLoD(OpTest):
def setUp(self):
self.op_type = 'bipartite_match'
lod = [[5, 6, 0, 12]]
dist = np.random.random((23, 217)).astype('float32')
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
self.inputs = {'DistMat': (dist, lod)}
self.outputs = {
'ColToRowMatchIndices': match_indices,
'ColToRowMatchDist': match_dist,
}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save