|
|
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/detection/nms_util.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -59,6 +60,9 @@ class MatrixNMSOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
|
|
|
|
|
ctx->SetOutputDim("Index", {box_dims[1], 1});
|
|
|
|
|
if (ctx->HasOutput("RoisNum")) {
|
|
|
|
|
ctx->SetOutputDim("RoisNum", {-1});
|
|
|
|
|
}
|
|
|
|
|
if (!ctx->IsRuntime()) {
|
|
|
|
|
ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
|
|
|
|
|
ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
|
|
|
|
@ -259,8 +263,10 @@ class MatrixNMSKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::vector<size_t> offsets = {0};
|
|
|
|
|
std::vector<T> detections;
|
|
|
|
|
std::vector<int> indices;
|
|
|
|
|
std::vector<int> num_per_batch;
|
|
|
|
|
detections.reserve(out_dim * num_boxes * batch_size);
|
|
|
|
|
indices.reserve(num_boxes * batch_size);
|
|
|
|
|
num_per_batch.reserve(batch_size);
|
|
|
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
|
|
|
scores_slice = scores->Slice(i, i + 1);
|
|
|
|
|
scores_slice.Resize({score_dims[1], score_dims[2]});
|
|
|
|
@ -272,6 +278,7 @@ class MatrixNMSKernel : public framework::OpKernel<T> {
|
|
|
|
|
background_label, nms_top_k, keep_top_k, normalized, score_threshold,
|
|
|
|
|
post_threshold, use_gaussian, gaussian_sigma);
|
|
|
|
|
offsets.push_back(offsets.back() + num_out);
|
|
|
|
|
num_per_batch.emplace_back(num_out);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t num_kept = offsets.back();
|
|
|
|
@ -285,6 +292,12 @@ class MatrixNMSKernel : public framework::OpKernel<T> {
|
|
|
|
|
std::copy(indices.begin(), indices.end(), index->data<int>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx.HasOutput("RoisNum")) {
|
|
|
|
|
auto* rois_num = ctx.Output<Tensor>("RoisNum");
|
|
|
|
|
rois_num->mutable_data<int>({batch_size}, ctx.GetPlace());
|
|
|
|
|
std::copy(num_per_batch.begin(), num_per_batch.end(),
|
|
|
|
|
rois_num->data<int>());
|
|
|
|
|
}
|
|
|
|
|
framework::LoD lod;
|
|
|
|
|
lod.emplace_back(offsets);
|
|
|
|
|
outs->set_lod(lod);
|
|
|
|
@ -355,6 +368,8 @@ class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
|
|
|
|
|
"index of selected bbox. The index is the absolute index cross "
|
|
|
|
|
"batches.");
|
|
|
|
|
AddOutput("RoisNum", "(Tensor), Number of RoIs in each image.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
This operator does multi-class matrix non maximum suppression (NMS) on batched
|
|
|
|
|
boxes and scores.
|
|
|
|
@ -369,7 +384,9 @@ This operator support multi-class and batched inputs. It applying NMS
|
|
|
|
|
independently for each class. The output is a 2-D LoDTensor; for each
|
|
|
|
|
image, the offsets in first dimension of LoDTensor are called LoD, the number
|
|
|
|
|
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
|
|
|
|
|
means there is no detected bbox for this image.
|
|
|
|
|
means there is no detected bbox for this image. Now this operator has one more
|
|
|
|
|
output, which is RoisNum. The size of RoisNum is N, and RoisNum[i] is the number of
|
|
|
|
|
detected bboxes for the i-th image.
|
|
|
|
|
|
|
|
|
|
For more information on Matrix NMS, please refer to:
|
|
|
|
|
https://arxiv.org/abs/2003.10152
|
|
|
|
@ -387,3 +404,8 @@ REGISTER_OPERATOR(
|
|
|
|
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
// Register the CPU kernels of the matrix_nms operator for the float and
// double element types.
REGISTER_OP_CPU_KERNEL(matrix_nms, ops::MatrixNMSKernel<float>,
|
|
|
|
|
ops::MatrixNMSKernel<double>);
|
|
|
|
|
// Record an op-version checkpoint for matrix_nms describing the newly added
// output "RoisNum" (declared as dispensable in the op maker).
REGISTER_OP_VERSION(matrix_nms)
|
|
|
|
|
.AddCheckpoint(
|
|
|
|
|
R"ROC(Upgrade matrix_nms: add a new output [RoisNum].)ROC",
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc().NewOutput(
|
|
|
|
|
"RoisNum", "The number of RoIs in each image."));
|
|
|
|
|