Merge pull request #7695 from qingqing01/bipartite_match_op
Add bipartite matching operator and unit testing.fix-profile-doc-typo
commit
2b19a68cc9
@ -0,0 +1,190 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
constexpr char kEPS = 1e-6;
|
||||
|
||||
class BipartiteMatchOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("DistMat"),
|
||||
"Input(DistMat) of BipartiteMatch should not be null.");
|
||||
|
||||
auto dims = ctx->GetInputDim("DistMat");
|
||||
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
|
||||
|
||||
ctx->SetOutputDim("ColToRowMatchIndices", dims);
|
||||
ctx->SetOutputDim("ColToRowMatchDis", dims);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BipartiteMatchKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
// The match_indices must be initialized to -1 at first.
|
||||
// The match_dist must be initialized to 0 at first.
|
||||
void BipartiteMatch(const Tensor& dist, int* match_indices,
|
||||
T* match_dist) const {
|
||||
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
|
||||
int64_t row = dist.dims()[0];
|
||||
int64_t col = dist.dims()[1];
|
||||
auto* dist_data = dist.data<T>();
|
||||
std::vector<int> row_pool;
|
||||
for (int i = 0; i < row; ++i) {
|
||||
row_pool.push_back(i);
|
||||
}
|
||||
while (row_pool.size() > 0) {
|
||||
int max_idx = -1;
|
||||
int max_row_idx = -1;
|
||||
T max_dist = -1;
|
||||
for (int64_t j = 0; j < col; ++j) {
|
||||
if (match_indices[j] != -1) {
|
||||
continue;
|
||||
}
|
||||
for (size_t k = 0; k < row_pool.size(); ++k) {
|
||||
int m = row_pool[k];
|
||||
// distance is 0 between m-th row and j-th column
|
||||
if (dist_data[m * col + j] < kEPS) {
|
||||
continue;
|
||||
}
|
||||
if (dist_data[m * col + j] > max_dist) {
|
||||
max_idx = j;
|
||||
max_row_idx = m;
|
||||
max_dist = dist_data[m * col + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (max_idx == -1) {
|
||||
// Cannot find good match.
|
||||
break;
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
|
||||
match_indices[max_idx] = max_row_idx;
|
||||
match_dist[max_idx] = max_dist;
|
||||
// Erase the row index.
|
||||
row_pool.erase(
|
||||
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* dist_mat = context.Input<LoDTensor>("DistMat");
|
||||
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
|
||||
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
|
||||
|
||||
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
|
||||
|
||||
auto col = dist_mat->dims()[1];
|
||||
|
||||
int64_t n = dist_mat->lod().size() == 0UL
|
||||
? 1
|
||||
: static_cast<int64_t>(dist_mat->lod().back().size() - 1);
|
||||
if (dist_mat->lod().size()) {
|
||||
PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL,
|
||||
"Only support 1 level of LoD.");
|
||||
}
|
||||
match_indices->mutable_data<int>({n, col}, context.GetPlace());
|
||||
match_dist->mutable_data<T>({n, col}, context.GetPlace());
|
||||
|
||||
math::SetConstant<platform::CPUDeviceContext, int> iset;
|
||||
iset(dev_ctx, match_indices, static_cast<int>(-1));
|
||||
math::SetConstant<platform::CPUDeviceContext, T> tset;
|
||||
tset(dev_ctx, match_dist, static_cast<T>(0));
|
||||
|
||||
int* indices = match_indices->data<int>();
|
||||
T* dist = match_dist->data<T>();
|
||||
if (n == 1) {
|
||||
BipartiteMatch(*dist_mat, indices, dist);
|
||||
} else {
|
||||
auto lod = dist_mat->lod().back();
|
||||
for (size_t i = 0; i < lod.size() - 1; ++i) {
|
||||
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
|
||||
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput(
|
||||
"DistMat",
|
||||
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
|
||||
"[K, M]. It is pair-wise distance matrix between the entities "
|
||||
"represented by each row and each column. For example, assumed one "
|
||||
"entity is A with shape [K], another entity is B with shape [M]. The "
|
||||
"DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
|
||||
"the distance is, the better macthing the pairs are. Please note, "
|
||||
"This tensor can contain LoD information to represent a batch of "
|
||||
"inputs. One instance of this batch can contain different numbers of "
|
||||
"entities.");
|
||||
AddOutput("ColToRowMatchIndices",
|
||||
"(Tensor) A 2-D Tensor with shape [N, M] in int type. "
|
||||
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
|
||||
"means B[j] does not match any entity in i-th instance. "
|
||||
"Otherwise, it means B[j] is matched to row "
|
||||
"ColToRowMatchIndices[i][j] in i-th instance. The row number of "
|
||||
"i-th instance is saved in ColToRowMatchIndices[i][j].");
|
||||
AddOutput("ColToRowMatchDis",
|
||||
"(Tensor) A 2-D Tensor with shape [N, M] in float type. "
|
||||
"N is batch size. If ColToRowMatchIndices[i][j] is -1, "
|
||||
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
|
||||
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
|
||||
"instance are called LoD. Then "
|
||||
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
|
||||
AddComment(R"DOC(
|
||||
This operator is a greedy bipartite matching algorithm, which is used to
|
||||
obtain the matching with the maximum distance based on the input
|
||||
distance matrix. For input 2D matrix, the bipartite matching algorithm can
|
||||
find the matched column for each row, also can find the matched row for
|
||||
each column. And this operator only calculate matched indices from column
|
||||
to row. For each instance, the number of matched indices is the number of
|
||||
of columns of the input ditance matrix.
|
||||
|
||||
There are two outputs to save matched indices and distance.
|
||||
A simple description, this algothrim matched the best (maximum distance)
|
||||
row entity to the column entity and the matched indices are not duplicated
|
||||
in each row of ColToRowMatchIndices. If the column entity is not matched
|
||||
any row entity, set -1 in ColToRowMatchIndices.
|
||||
|
||||
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
|
||||
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
|
||||
If Tensor, the height of ColToRowMatchIndices is 1.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(bipartite_match, ops::BipartiteMatchOp,
|
||||
ops::BipartiteMatchOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(bipartite_match, ops::BipartiteMatchKernel<float>,
|
||||
ops::BipartiteMatchKernel<double>);
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def bipartite_match(distance, match_indices, match_dis):
|
||||
"""Bipartite Matching algorithm.
|
||||
Arg:
|
||||
distance (numpy.array) : The distance of two entries with shape [M, N].
|
||||
match_indices (numpy.array): the matched indices from column to row
|
||||
with shape [1, N], it must be initialized to -1.
|
||||
match_dis (numpy.array): The matched distance from column to row
|
||||
with shape [1, N], it must be initialized to 0.
|
||||
"""
|
||||
match_pair = []
|
||||
row, col = distance.shape
|
||||
for i in range(row):
|
||||
for j in range(col):
|
||||
match_pair.append((i, j, distance[i][j]))
|
||||
|
||||
match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True)
|
||||
|
||||
row_indices = -1 * np.ones((row, ), dtype=np.int)
|
||||
|
||||
idx = 0
|
||||
for i, j, dis in match_sorted:
|
||||
if idx >= row:
|
||||
break
|
||||
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
|
||||
match_indices[j] = i
|
||||
row_indices[i] = j
|
||||
match_dis[j] = dis
|
||||
idx += 1
|
||||
|
||||
|
||||
def batch_bipartite_match(distance, lod):
|
||||
"""Bipartite Matching algorithm for batch input.
|
||||
Arg:
|
||||
distance (numpy.array) : The distance of two entries with shape [M, N].
|
||||
lod (list of int): The offsets of each input in this batch.
|
||||
"""
|
||||
n = len(lod) - 1
|
||||
m = distance.shape[1]
|
||||
match_indices = -1 * np.ones((n, m), dtype=np.int)
|
||||
match_dis = np.zeros((n, m), dtype=np.float32)
|
||||
for i in range(len(lod) - 1):
|
||||
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
|
||||
match_dis[i, :])
|
||||
return match_indices, match_dis
|
||||
|
||||
|
||||
class TestBipartiteMatchOpForWithLoD(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = 'bipartite_match'
|
||||
lod = [[0, 5, 11, 23]]
|
||||
dis = np.random.random((23, 217)).astype('float32')
|
||||
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
|
||||
|
||||
self.inputs = {'DistMat': (dis, lod)}
|
||||
self.outputs = {
|
||||
'ColToRowMatchIndices': (match_indices),
|
||||
'ColToRowMatchDis': (match_dis),
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestBipartiteMatchOpWithoutLoD(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = 'bipartite_match'
|
||||
lod = [[0, 8]]
|
||||
dis = np.random.random((8, 17)).astype('float32')
|
||||
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
|
||||
|
||||
self.inputs = {'DistMat': dis}
|
||||
self.outputs = {
|
||||
'ColToRowMatchIndices': (match_indices),
|
||||
'ColToRowMatchDis': (match_dis),
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue