|
|
|
@ -65,7 +65,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
|
|
|
|
|
"""Bipartite Matching algorithm for batch input.
|
|
|
|
|
Arg:
|
|
|
|
|
distance (numpy.array) : The distance of two entries with shape [M, N].
|
|
|
|
|
lod (list of int): The offsets of each input in this batch.
|
|
|
|
|
lod (list of int): The length of each input in this batch.
|
|
|
|
|
"""
|
|
|
|
|
n = len(lod)
|
|
|
|
|
m = distance.shape[1]
|
|
|
|
@ -73,6 +73,7 @@ def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
|
|
|
|
|
match_dist = np.zeros((n, m), dtype=np.float32)
|
|
|
|
|
cur_offset = 0
|
|
|
|
|
for i in range(n):
|
|
|
|
|
if lod[i] == 0: continue
|
|
|
|
|
bipartite_match(distance[cur_offset:(cur_offset + lod[i]), :],
|
|
|
|
|
match_indices[i, :], match_dist[i, :])
|
|
|
|
|
if match_type == 'per_prediction':
|
|
|
|
@ -155,5 +156,22 @@ class TestBipartiteMatchOpWithPerPredictionType(OpTest):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBipartiteMatchOpWithEmptyLoD(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = 'bipartite_match'
|
|
|
|
|
lod = [[5, 6, 0, 12]]
|
|
|
|
|
dist = np.random.random((23, 217)).astype('float32')
|
|
|
|
|
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
|
|
|
|
|
|
|
|
|
|
self.inputs = {'DistMat': (dist, lod)}
|
|
|
|
|
self.outputs = {
|
|
|
|
|
'ColToRowMatchIndices': match_indices,
|
|
|
|
|
'ColToRowMatchDist': match_dist,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|