|
|
|
@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("DisMat"),
|
|
|
|
|
"Input(DisMat) of BipartiteMatch should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("DistMat"),
|
|
|
|
|
"Input(DistMat) of BipartiteMatch should not be null.");
|
|
|
|
|
|
|
|
|
|
auto dims = ctx->GetInputDim("DisMat");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2.");
|
|
|
|
|
auto dims = ctx->GetInputDim("DistMat");
|
|
|
|
|
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("ColToRowMatchIndices", dims);
|
|
|
|
|
ctx->SetOutputDim("ColToRowMatchDis", dims);
|
|
|
|
@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto* dist_mat = context.Input<LoDTensor>("DisMat");
|
|
|
|
|
auto* dist_mat = context.Input<LoDTensor>("DistMat");
|
|
|
|
|
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
|
|
|
|
|
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
|
|
|
|
|
|
|
|
|
@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(
|
|
|
|
|
"DisMat",
|
|
|
|
|
"DistMat",
|
|
|
|
|
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
|
|
|
|
|
"[K, M]. It is pair-wise distance matrix between the entities "
|
|
|
|
|
"represented by each row and each column. For example, assumed one "
|
|
|
|
|
"entity is A with shape [K], another entity is B with shape [M]. The "
|
|
|
|
|
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
|
|
|
|
|
"DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
|
|
|
|
|
"the distance is, the better macthing the pairs are. Please note, "
|
|
|
|
|
"This tensor can contain LoD information to represent a batch of "
|
|
|
|
|
"inputs. One instance of this batch can contain different numbers of "
|
|
|
|
@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
|
|
|
|
|
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
|
|
|
|
|
"instance are called LoD. Then "
|
|
|
|
|
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
|
|
|
|
|
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
This operator is a greedy bipartite matching algorithm, which is used to
|
|
|
|
|
obtain the matching with the maximum distance based on the input
|
|
|
|
@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated
|
|
|
|
|
in each row of ColToRowMatchIndices. If the column entity is not matched
|
|
|
|
|
any row entity, set -1 in ColToRowMatchIndices.
|
|
|
|
|
|
|
|
|
|
Please note that the input DisMat can be LoDTensor (with LoD) or Tensor.
|
|
|
|
|
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
|
|
|
|
|
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
|
|
|
|
|
If Tensor, the height of ColToRowMatchIndices is 1.
|
|
|
|
|
|
|
|
|
|