From 912a4f2511ad118d7a989cbe4e7f634503670e34 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 29 Jan 2018 23:49:56 +0800 Subject: [PATCH 1/6] Add multi-class non-maximum suppression operator. --- paddle/operators/multiclass_nms_op.cc | 353 ++++++++++++++++++ .../v2/fluid/tests/test_bipartite_match_op.py | 2 +- .../v2/fluid/tests/test_multiclass_nms_op.py | 199 ++++++++++ 3 files changed, 553 insertions(+), 1 deletion(-) create mode 100644 paddle/operators/multiclass_nms_op.cc create mode 100644 python/paddle/v2/fluid/tests/test_multiclass_nms_op.py diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc new file mode 100644 index 0000000000..19c5b7efd6 --- /dev/null +++ b/paddle/operators/multiclass_nms_op.cc @@ -0,0 +1,353 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +constexpr int64_t kOutputDim = 6; +constexpr int64_t kBBoxSize = 4; + +class MulticlassNMSOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Bboxes"), + "Input(Bboxes) of MulticlassNMS should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Scores"), + "Input(Scores) of MulticlassNMS should not be null."); + + auto box_dims = ctx->GetInputDim("Bboxes"); + auto score_dims = ctx->GetInputDim("Scores"); + + PADDLE_ENFORCE_EQ(box_dims.size(), 3, + "The rank of Input(Bboxes) must be 3."); + PADDLE_ENFORCE_EQ(score_dims.size(), 3, + "The rank of Input(Scores) must be 3."); + PADDLE_ENFORCE_EQ(box_dims[0], score_dims[0]); + PADDLE_ENFORCE_EQ(box_dims[2], 4); + PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2]); + + // Here the box_dims[0] is not the real dimension of output. + // It will be rewritten in the computing kernel. + ctx->SetOutputDim("Out", {box_dims[0], 6}); + } +}; + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +static inline void GetMaxScoreIndex( + const std::vector& scores, const T threshold, int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. 
+ if (top_k > -1 && top_k < sorted_indices->size()) { + sorted_indices->resize(top_k); + } +} + +template +T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return T(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If bbox is not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const T* box1, const T* box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + const T inter_w = inter_xmax - inter_xmin; + const T inter_h = inter_ymax - inter_ymin; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +class MulticlassNMSKernel : public framework::OpKernel { + public: + void NMSFast(const Tensor& bbox, const Tensor& scores, + const T score_threshold, const T nms_threshold, const T eta, + const int64_t top_k, std::vector* selected_indices) const { + // The total boxes for each instance. 
+ int64_t num_boxes = bbox.dims()[0]; + // 4: [xmin ymin xmax ymax] + int64_t box_size = bbox.dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores.data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); + + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + const T* bbox_data = bbox.data(); + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (int k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, true); + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } + } + + void MulticlassNMS(const framework::ExecutionContext& ctx, + const Tensor& scores, const Tensor& bboxes, + std::map>* indices, + int* num_nmsed_out) const { + int64_t background_label = ctx.Attr("background_label"); + int64_t nms_top_k = ctx.Attr("nms_top_k"); + int64_t keep_top_k = ctx.Attr("keep_top_k"); + T nms_threshold = static_cast(ctx.Attr("nms_threshold")); + T nms_eta = static_cast(ctx.Attr("nms_eta")); + T score_threshold = static_cast(ctx.Attr("confidence_threshold")); + + int64_t class_num = scores.dims()[0]; + int64_t predict_dim = scores.dims()[1]; + int num_det = 0; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + Tensor score = scores.Slice(c, c + 1); + NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, + &((*indices)[c])); + num_det += indices[c].size(); + } + + *num_nmsed_out = num_det; + const T* scores_data = scores.data(); + if (keep_top_k > -1 && num_det > keep_top_k) { + std::vector>> 
score_index_pairs; + for (const auto& it : *indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& label_indices = it.second; + for (int j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + PADDLE_ENFORCE_LT(idx, predict_dim); + score_index_pairs.push_back( + std::make_pair(sdata[idx], std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + // Store the new indices. + std::map> new_indices; + for (int j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; + } + } + + void MulticlassOutput(const Tensor& scores, const Tensor& bboxes, + std::map>& selected_indices, + Tensor* outs) const { + int predict_dim = scores.dims()[1]; + auto* scores_data = scores.data(); + auto* bboxes_data = bboxes.data(); + auto* odata = outs->data(); + + int count = 0; + for (const auto& it : selected_indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + std::vector indices = it.second; + for (int j = 0; j < indices.size(); ++j) { + int idx = indices[j]; + const T* bdata = bboxes_data + idx * kBBoxSize; + odata[count * kOutputDim] = label; // label + odata[count * kOutputDim + 1] = sdata[idx]; // score + odata[count * kOutputDim + 2] = bdata[0]; // xmin + odata[count * kOutputDim + 3] = bdata[1]; // ymin + odata[count * kOutputDim + 4] = bdata[2]; // xmax + odata[count * kOutputDim + 5] = bdata[3]; // ymax + } + count++; + } + } + + void Compute(const framework::ExecutionContext& ctx) const override { + auto* boxes = ctx.Input("Bboxes"); + auto* scores = ctx.Input("Scores"); + auto* outs = ctx.Output("Out"); + + auto box_dims = boxes->dims(); + auto 
score_dims = scores->dims(); + + int64_t batch_size = box_dims[0]; + int64_t class_num = score_dims[1]; + int64_t predict_dim = score_dims[2]; + + std::vector>> all_indices; + std::vector batch_starts = {0}; + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + std::map> indices; + int num_nmsed_out = 0; + MulticlassNMS(ctx, ins_score, *boxes, &indices, &num_nmsed_out); + all_indices.push_back(indices); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + int num_kept = batch_starts.back(); + if (num_kept == 0) { + outs->Resize({0, 0}); + } else { + outs->mutable_data({num_kept, kOutputDim}, ctx.GetPlace()); + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + int64_t s = batch_starts[i]; + int64_t e = batch_starts[i + 1]; + if (e > s) { + Tensor out = outs->Slice(s, e); + MulticlassOutput(ins_score, *boxes, all_indices[i], &out); + } + } + } + + framework::LoD lod; + lod.emplace_back(batch_starts); + + outs->set_lod(lod); + } +}; + +class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MulticlassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Bboxes", + "(Tensor) A 2-D Tensor with shape [M, 4] represents the location " + "predictions with M bboxes. 4 is the number of " + "each location coordinates."); + AddOutput("Scores", + "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " + "confidence predictions. 
N is the batch size, C is the class " + "number, M is number of predictions for each class, which is " + "the same with Bboxes."); + AddAttr( + "background_label", + "(int64_t, defalut: 0) " + "The index of background label, the background label will be ignored.") + .SetDefault(0); + AddAttr("nms_threshold", + "(float, defalut: 0.3) " + "The threshold to be used in nms.") + .SetDefault(0.3); + AddAttr("nms_top_k", + "(int64_t) " + " ."); + AddAttr("nms_eta", + "(float) " + "The parameter for adaptive nms.") + .SetDefault(1.0); + AddAttr("keep_top_k", + "(int64_t) " + "."); + AddAttr("confidence_threshold", + "(float) " + "."); + AddOutput("Out", + "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], No is the total " + "number of detections in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected bbox."); + AddComment(R"DOC( +This operators is to do multi-class non maximum suppression (nms) on a batched +of boxes and scores. + +This op greedily selects a subset of detection bounding boxes, pruning +away boxes that have high IOU (intersection over union) overlap (> thresh) +with already selected boxes. It operates independently for each class for +which scores are provided (via the scores field of the input box_list), +pruning boxes with score less than a provided threshold prior to +applying NMS. 
+ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(multiclass_nms, ops::MulticlassNMSOp, + ops::MulticlassNMSOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MulticlassNMSKernel, + ops::MulticlassNMSKernel); diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py index 7413829897..c35fb20b10 100644 --- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py +++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py @@ -62,7 +62,7 @@ def batch_bipartite_match(distance, lod): return match_indices, match_dist -class TestBipartiteMatchOpForWithLoD(OpTest): +class TestBipartiteMatchOpWithLoD(OpTest): def setUp(self): self.op_type = 'bipartite_match' lod = [[0, 5, 11, 23]] diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py new file mode 100644 index 0000000000..60c6488f84 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -0,0 +1,199 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+import unittest +import numpy as np +import copy +from op_test import OpTest + + +def iou(box_a, box_b): + """Apply intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a) + area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) + + box_a_area = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]) + box_b_area = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + + return iou_ratio + + +def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. 
+ """ + all_scores = copy.deepcopy(scores) + all_scores = all_scores.flatten() + selected_indices = np.argwhere(all_scores > score_threshold) + selected_indices = selected_indices.flatten() + all_scores = all_scores[selected_indices] + + sorted_indices = np.argsort(-all_scores, axis=0) + sorted_scores = all_scores[sorted_indices] + if top_k < -1 and top_k < sorted_indices.shape[0]: + sorted_indices = sorted_indices[:top_k] + sorted_scores = sorted_scores[:top_k] + + selected_indices = [] + adaptive_threshold = nms_threshold + for i in range(sorted_scores.shape[0]): + idx = sorted_indices[i] + keep = True + for k in range(len(selected_indices)): + if keep: + kept_idx = selected_indices[k] + overlap = iou(boxes[idx], boxes[kept_idx]) + keep = overlap <= adaptive_threshold + else: + break + if keep: + selected_indices.append(idx) + if keep and eta < 1 and adaptive_threshold > 0.5: + adaptive_threshold *= eta + return selected_indices + + +def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, + nms_top_k, keep_top_k): + class_num = scores.shape[0] + priorbox_num = scores.shape[1] + + selected_indices = [] + num_det = 0 + for c in range(class_num): + if c == background: continue + indices = nms(boxes, scores[c], score_threshold, nms_threshold, + nms_top_k) + selected_indices.append((c, indices)) + num_det += len(indices) + + if keep_top_k > -1 and num_det > keep_top_k: + score_index = [] + for c, indices in selected_indices: + for idx in indices: + score_index.append((scores[c][idx], c, idx)) + + sorted_score_index = sorted( + score_index, key=lambda tup: tup[0], reverse=True) + sorted_score_index = sorted_score_index[:keep_top_k] + selected_indices = [] + for s, c, idx in sorted_score_index: + selected_indices.append((c, idx)) + + return selected_indices + + +def batched_multiclass_nms(boxes, scores, background, score_threshold, + nms_threshold, nms_top_k, keep_top_k): + batch_size = scores.shape[0] + + det_outs = [] + lod = [0] + for n in 
range(batch_size): + nmsed_outs = multiclass_nms(boxes, scores[n], background, + score_threshold, nms_threshold, nms_top_k, + keep_top_k) + lod.append(lod[-1] + len(nmsed_outs)) + if len(nmsed_outs) == 0: continue + for c, indices in nmsed_outs: + for idx in indices: + xmin, ymin, xmax, ymax = boxes[idx][:] + det_outs.append( + (c, scores[n][c][idx], c, xmin, ymin, xmax, ymax)) + return det_outs, lod + + +class TestMulticlassNMSOp(OpTest): + def setUp(self): + self.op_type = 'multiclass_nms' + N = 7 + M = 1230 + C = 21 + BOX_SIZE = 4 + background = 0 + nms_threshold = 0.3 + nms_top_k = 400 + keep_top_k = 200 + score_threshold = 0.01 + + scores = np.random.random((N, C, M)).astype('float32') + boxes = np.random.random((M, BOX_SIZE)).astype('float32') + boxes[:, 0:2] = boxes[:, 0:2] * 0.5 + boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 + + nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, + score_threshold, nms_threshold, + nms_top_k, keep_top_k) + self.inputs = {'Bboxes': boxes, 'Scores': scores} + self.outputs = {'Out': (nmsed_outs, [lod])} + + def test_check_output(self): + self.check_output() + + +class TestIOU(unittest.TestCase): + def test_iou(self): + box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32') + box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32') + + expt_output = np.array([2.0 / 16.0]).astype('float32') + calc_output = np.array([iou(box1, box2)]).astype('float32') + self.assertTrue(np.allclose(calc_output, expt_output)) + + +if __name__ == '__main__': + unittest.main() + # N = 7 + # M = 8 + # C = 5 + # BOX_SIZE = 4 + # background = 0 + # nms_threshold = 0.3 + # nms_top_k = 400 + # keep_top_k = 200 + # score_threshold = 0.5 + + # scores = np.random.random((N, C, M)).astype('float32') + # boxes = np.random.random((M, BOX_SIZE)).astype('float32') + # boxes[:, 0 : 2] = boxes[:, 0 : 2] * 0.5 + # boxes[:, 2 : 4] = boxes[:, 0 : 2] * 0.5 + 0.5 + # print nmsed_outs, lod From 2731fd96606b18411b485269e36fd44ae8909650 Mon Sep 17 00:00:00 
2001 From: dangqingqing Date: Tue, 30 Jan 2018 00:19:28 +0800 Subject: [PATCH 2/6] Update doc for multiclass_nms_op. --- paddle/operators/multiclass_nms_op.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 19c5b7efd6..5da553a6cc 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -37,13 +37,12 @@ class MulticlassNMSOp : public framework::OperatorWithKernel { auto box_dims = ctx->GetInputDim("Bboxes"); auto score_dims = ctx->GetInputDim("Scores"); - PADDLE_ENFORCE_EQ(box_dims.size(), 3, + PADDLE_ENFORCE_EQ(box_dims.size(), 2, "The rank of Input(Bboxes) must be 3."); PADDLE_ENFORCE_EQ(score_dims.size(), 3, "The rank of Input(Scores) must be 3."); - PADDLE_ENFORCE_EQ(box_dims[0], score_dims[0]); PADDLE_ENFORCE_EQ(box_dims[2], 4); - PADDLE_ENFORCE_EQ(box_dims[1], score_dims[2]); + PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. @@ -308,17 +307,19 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(0.3); AddAttr("nms_top_k", "(int64_t) " - " ."); + "Maximum number of results to be kept."); AddAttr("nms_eta", "(float) " "The parameter for adaptive nms.") .SetDefault(1.0); AddAttr("keep_top_k", "(int64_t) " - "."); + "Number of total bboxes to be kept per image after nms " + "step. -1 means keeping all bboxes after nms step."); AddAttr("confidence_threshold", "(float) " - "."); + "Only consider detections whose confidences are larger than " + "a threshold. If not provided, consider all boxes."); AddOutput("Out", "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " "detections. 
Each row has 6 values: " @@ -328,15 +329,14 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "no detected bbox."); AddComment(R"DOC( -This operators is to do multi-class non maximum suppression (nms) on a batched +This operators is to do multi-class non maximum suppression (NMS) on a batched of boxes and scores. This op greedily selects a subset of detection bounding boxes, pruning away boxes that have high IOU (intersection over union) overlap (> thresh) with already selected boxes. It operates independently for each class for -which scores are provided (via the scores field of the input box_list), -pruning boxes with score less than a provided threshold prior to -applying NMS. +which scores are provided, pruning boxes with score less than a provided +threshold prior to applying NMS. )DOC"); } From 35dec3d7228e2f924ccc6549a420604110640337 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 30 Jan 2018 17:59:48 +0800 Subject: [PATCH 3/6] Fix bug in unit test. --- paddle/operators/multiclass_nms_op.cc | 84 +++++++++++-------- .../v2/fluid/tests/test_multiclass_nms_op.py | 61 +++++++------- 2 files changed, 82 insertions(+), 63 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 5da553a6cc..93c8b5216f 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -41,13 +41,22 @@ class MulticlassNMSOp : public framework::OperatorWithKernel { "The rank of Input(Bboxes) must be 3."); PADDLE_ENFORCE_EQ(score_dims.size(), 3, "The rank of Input(Scores) must be 3."); - PADDLE_ENFORCE_EQ(box_dims[2], 4); + PADDLE_ENFORCE_EQ(box_dims[1], 4); PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. 
ctx->SetOutputDim("Out", {box_dims[0], 6}); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("Scores")->type()), + ctx.device_context()); + } }; template @@ -158,12 +167,12 @@ class MulticlassNMSKernel : public framework::OpKernel { const Tensor& scores, const Tensor& bboxes, std::map>* indices, int* num_nmsed_out) const { - int64_t background_label = ctx.Attr("background_label"); - int64_t nms_top_k = ctx.Attr("nms_top_k"); - int64_t keep_top_k = ctx.Attr("keep_top_k"); + int64_t background_label = ctx.Attr("background_label"); + int64_t nms_top_k = ctx.Attr("nms_top_k"); + int64_t keep_top_k = ctx.Attr("keep_top_k"); T nms_threshold = static_cast(ctx.Attr("nms_threshold")); T nms_eta = static_cast(ctx.Attr("nms_eta")); - T score_threshold = static_cast(ctx.Attr("confidence_threshold")); + T score_threshold = static_cast(ctx.Attr("score_threshold")); int64_t class_num = scores.dims()[0]; int64_t predict_dim = scores.dims()[1]; @@ -173,7 +182,7 @@ class MulticlassNMSKernel : public framework::OpKernel { Tensor score = scores.Slice(c, c + 1); NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, &((*indices)[c])); - num_det += indices[c].size(); + num_det += (*indices)[c].size(); } *num_nmsed_out = num_det; @@ -230,8 +239,8 @@ class MulticlassNMSKernel : public framework::OpKernel { odata[count * kOutputDim + 3] = bdata[1]; // ymin odata[count * kOutputDim + 4] = bdata[2]; // xmax odata[count * kOutputDim + 5] = bdata[3]; // ymax + count++; } - count++; } } @@ -240,10 +249,9 @@ class MulticlassNMSKernel : public framework::OpKernel { auto* scores = ctx.Input("Scores"); auto* outs = ctx.Output("Out"); - auto box_dims = boxes->dims(); auto score_dims = scores->dims(); - int64_t batch_size = box_dims[0]; + int64_t batch_size = score_dims[0]; int64_t class_num = score_dims[1]; int64_t predict_dim = 
score_dims[2]; @@ -291,35 +299,37 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) A 2-D Tensor with shape [M, 4] represents the location " "predictions with M bboxes. 4 is the number of " "each location coordinates."); - AddOutput("Scores", - "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " - "confidence predictions. N is the batch size, C is the class " - "number, M is number of predictions for each class, which is " - "the same with Bboxes."); - AddAttr( + AddInput("Scores", + "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " + "confidence predictions. N is the batch size, C is the class " + "number, M is number of predictions for each class, which is " + "the same with Bboxes."); + AddAttr( "background_label", "(int64_t, defalut: 0) " "The index of background label, the background label will be ignored.") .SetDefault(0); + AddAttr("score_threshold", + "(float) " + "Only consider detections whose confidences are larger than " + "a threshold. If not provided, consider all boxes."); + AddAttr("nms_top_k", + "(int64_t) " + "Maximum number of detections to be kept according to the " + "confidences aftern the filtering detections based on " + "score_threshold"); AddAttr("nms_threshold", "(float, defalut: 0.3) " - "The threshold to be used in nms.") + "The threshold to be used in NMS.") .SetDefault(0.3); - AddAttr("nms_top_k", - "(int64_t) " - "Maximum number of results to be kept."); AddAttr("nms_eta", "(float) " - "The parameter for adaptive nms.") + "The parameter for adaptive NMS.") .SetDefault(1.0); - AddAttr("keep_top_k", - "(int64_t) " - "Number of total bboxes to be kept per image after nms " - "step. -1 means keeping all bboxes after nms step."); - AddAttr("confidence_threshold", - "(float) " - "Only consider detections whose confidences are larger than " - "a threshold. 
If not provided, consider all boxes."); + AddAttr("keep_top_k", + "(int64_t) " + "Number of total bboxes to be kept per image after NMS " + "step. -1 means keeping all bboxes after NMS step."); AddOutput("Out", "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the " "detections. Each row has 6 values: " @@ -329,15 +339,21 @@ class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " "no detected bbox."); AddComment(R"DOC( -This operators is to do multi-class non maximum suppression (NMS) on a batched +This operator is to do multi-class non maximum suppression (NMS) on a batched of boxes and scores. -This op greedily selects a subset of detection bounding boxes, pruning -away boxes that have high IOU (intersection over union) overlap (> thresh) -with already selected boxes. It operates independently for each class for -which scores are provided, pruning boxes with score less than a provided -threshold prior to applying NMS. +In the NMS step, this operator greedily selects a subset of detection bounding +boxes that have high scores larger than score_threshold, if providing this +threshold, then selects the largest nms_top_k confidences scores if nms_top_k +is larger than -1. Then this operator pruns away boxes that have high IOU +(intersection over union) overlap with already selected boxes by adaptive +threshold NMS based on parameters of nms_threshold and nms_eta. + +Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept +per image if keep_top_k is larger than -1. +This operator support multi-class and batched inputs. It applying NMS +independently for each class. 
)DOC"); } }; diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py index 60c6488f84..b619c52e55 100644 --- a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -69,7 +69,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): sorted_indices = np.argsort(-all_scores, axis=0) sorted_scores = all_scores[sorted_indices] - if top_k < -1 and top_k < sorted_indices.shape[0]: + if top_k > -1 and top_k < sorted_indices.shape[0]: sorted_indices = sorted_indices[:top_k] sorted_scores = sorted_scores[:top_k] @@ -82,7 +82,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): if keep: kept_idx = selected_indices[k] overlap = iou(boxes[idx], boxes[kept_idx]) - keep = overlap <= adaptive_threshold + keep = True if overlap <= adaptive_threshold else False else: break if keep: @@ -103,14 +103,14 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, if c == background: continue indices = nms(boxes, scores[c], score_threshold, nms_threshold, nms_top_k) - selected_indices.append((c, indices)) + for idx in indices: + selected_indices.append((c, idx)) num_det += len(indices) if keep_top_k > -1 and num_det > keep_top_k: score_index = [] - for c, indices in selected_indices: - for idx in indices: - score_index.append((scores[c][idx], c, idx)) + for c, idx in selected_indices: + score_index.append((scores[c][idx], c, idx)) sorted_score_index = sorted( score_index, key=lambda tup: tup[0], reverse=True) @@ -134,19 +134,16 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold, keep_top_k) lod.append(lod[-1] + len(nmsed_outs)) if len(nmsed_outs) == 0: continue - for c, indices in nmsed_outs: - for idx in indices: - xmin, ymin, xmax, ymax = boxes[idx][:] - det_outs.append( - (c, scores[n][c][idx], c, xmin, ymin, xmax, ymax)) + for c, idx in nmsed_outs: + xmin, 
ymin, xmax, ymax = boxes[idx][:] + det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax]) return det_outs, lod class TestMulticlassNMSOp(OpTest): def setUp(self): - self.op_type = 'multiclass_nms' N = 7 - M = 1230 + M = 1240 C = 21 BOX_SIZE = 4 background = 0 @@ -155,7 +152,17 @@ class TestMulticlassNMSOp(OpTest): keep_top_k = 200 score_threshold = 0.01 - scores = np.random.random((N, C, M)).astype('float32') + scores = np.random.random((N * M, C)).astype('float32') + + def softmax(x): + shiftx = x - np.max(x).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + scores = np.apply_along_axis(softmax, 1, scores) + scores = np.reshape(scores, (N, M, C)) + scores = np.transpose(scores, (0, 2, 1)) + boxes = np.random.random((M, BOX_SIZE)).astype('float32') boxes[:, 0:2] = boxes[:, 0:2] * 0.5 boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 @@ -163,8 +170,19 @@ class TestMulticlassNMSOp(OpTest): nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, nms_top_k, keep_top_k) + nmsed_outs = np.array(nmsed_outs).astype('float32') + + self.op_type = 'multiclass_nms' self.inputs = {'Bboxes': boxes, 'Scores': scores} self.outputs = {'Out': (nmsed_outs, [lod])} + self.attrs = { + 'background_label': 0, + 'nms_threshold': nms_threshold, + 'nms_top_k': nms_top_k, + 'keep_top_k': keep_top_k, + 'score_threshold': score_threshold, + 'nms_eta': 1.0, + } def test_check_output(self): self.check_output() @@ -182,18 +200,3 @@ class TestIOU(unittest.TestCase): if __name__ == '__main__': unittest.main() - # N = 7 - # M = 8 - # C = 5 - # BOX_SIZE = 4 - # background = 0 - # nms_threshold = 0.3 - # nms_top_k = 400 - # keep_top_k = 200 - # score_threshold = 0.5 - - # scores = np.random.random((N, C, M)).astype('float32') - # boxes = np.random.random((M, BOX_SIZE)).astype('float32') - # boxes[:, 0 : 2] = boxes[:, 0 : 2] * 0.5 - # boxes[:, 2 : 4] = boxes[:, 0 : 2] * 0.5 + 0.5 - # print nmsed_outs, lod From 
537886408863f68d7863e8245d746d2c15ef55dd Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 30 Jan 2018 21:30:16 +0800 Subject: [PATCH 4/6] Fix the output order and add more unit test cases. --- paddle/operators/multiclass_nms_op.cc | 16 +++-- .../v2/fluid/tests/test_multiclass_nms_op.py | 68 +++++++++++++------ 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 93c8b5216f..4689306d24 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -201,8 +201,8 @@ class MulticlassNMSKernel : public framework::OpKernel { } } // Keep top k results per image. - std::sort(score_index_pairs.begin(), score_index_pairs.end(), - SortScorePairDescend>); + std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); score_index_pairs.resize(keep_top_k); // Store the new indices. @@ -269,7 +269,8 @@ class MulticlassNMSKernel : public framework::OpKernel { int num_kept = batch_starts.back(); if (num_kept == 0) { - outs->Resize({0, 0}); + T* od = outs->mutable_data({1}, ctx.GetPlace()); + od[0] = -1; } else { outs->mutable_data({num_kept, kOutputDim}, ctx.GetPlace()); for (int64_t i = 0; i < batch_size; ++i) { @@ -349,11 +350,16 @@ is larger than -1. Then this operator pruns away boxes that have high IOU (intersection over union) overlap with already selected boxes by adaptive threshold NMS based on parameters of nms_threshold and nms_eta. -Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept +Aftern NMS step, at most keep_top_k number of total bboxes are to be kept per image if keep_top_k is larger than -1. This operator support multi-class and batched inputs. It applying NMS -independently for each class. +independently for each class. 
The outputs is a 2-D LoDTenosr, for each +image, the offsets in first dimension of LoDTensor are called LoD, the number +of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0, +means there is no detected bbox for this image. If there is no detected boxes +for all images, all the elements in LoD are 0, and the Out only contains one +value which is -1. )DOC"); } }; diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py index b619c52e55..3097b8388c 100644 --- a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -56,8 +56,12 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): Args: boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. scores: (tensor) The class predscores for the img, Shape:[num_priors]. - overlap: (float) The overlap thresh for suppressing unnecessary boxes. - top_k: (int) The Maximum number of box preds to consider. + score_threshold: (float) The confidence thresh for filtering low + confidence boxes. + nms_threshold: (float) The overlap thresh for suppressing unnecessary + boxes. + top_k: (int) The maximum number of box preds to consider. + eta: (float) The parameter for adaptive NMS. Return: The indices of the kept boxes with respect to num_priors. 
""" @@ -67,7 +71,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): selected_indices = selected_indices.flatten() all_scores = all_scores[selected_indices] - sorted_indices = np.argsort(-all_scores, axis=0) + sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort') sorted_scores = all_scores[sorted_indices] if top_k > -1 and top_k < sorted_indices.shape[0]: sorted_indices = sorted_indices[:top_k] @@ -97,29 +101,33 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, class_num = scores.shape[0] priorbox_num = scores.shape[1] - selected_indices = [] + selected_indices = {} num_det = 0 for c in range(class_num): if c == background: continue indices = nms(boxes, scores[c], score_threshold, nms_threshold, nms_top_k) - for idx in indices: - selected_indices.append((c, idx)) + selected_indices[c] = indices num_det += len(indices) if keep_top_k > -1 and num_det > keep_top_k: score_index = [] - for c, idx in selected_indices: - score_index.append((scores[c][idx], c, idx)) + for c, indices in selected_indices.iteritems(): + for idx in indices: + score_index.append((scores[c][idx], c, idx)) sorted_score_index = sorted( score_index, key=lambda tup: tup[0], reverse=True) sorted_score_index = sorted_score_index[:keep_top_k] - selected_indices = [] + selected_indices = {} + + for _, c, _ in sorted_score_index: + selected_indices[c] = [] for s, c, idx in sorted_score_index: - selected_indices.append((c, idx)) + selected_indices[c].append(idx) + num_det = keep_top_k - return selected_indices + return selected_indices, num_det def batched_multiclass_nms(boxes, scores, background, score_threshold, @@ -129,28 +137,36 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold, det_outs = [] lod = [0] for n in range(batch_size): - nmsed_outs = multiclass_nms(boxes, scores[n], background, - score_threshold, nms_threshold, nms_top_k, - keep_top_k) - lod.append(lod[-1] + len(nmsed_outs)) - if len(nmsed_outs) 
== 0: continue - for c, idx in nmsed_outs: - xmin, ymin, xmax, ymax = boxes[idx][:] - det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax]) + nmsed_outs, nmsed_num = multiclass_nms(boxes, scores[n], background, + score_threshold, nms_threshold, + nms_top_k, keep_top_k) + lod.append(lod[-1] + nmsed_num) + if nmsed_num == 0: continue + + for c, indices in nmsed_outs.iteritems(): + for idx in indices: + xmin, ymin, xmax, ymax = boxes[idx][:] + det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax]) + return det_outs, lod class TestMulticlassNMSOp(OpTest): + def set_argument(self): + self.score_threshold = 0.01 + def setUp(self): + self.set_argument() N = 7 - M = 1240 + M = 1200 C = 21 BOX_SIZE = 4 + background = 0 nms_threshold = 0.3 nms_top_k = 400 keep_top_k = 200 - score_threshold = 0.01 + score_threshold = self.score_threshold scores = np.random.random((N * M, C)).astype('float32') @@ -165,11 +181,12 @@ class TestMulticlassNMSOp(OpTest): boxes = np.random.random((M, BOX_SIZE)).astype('float32') boxes[:, 0:2] = boxes[:, 0:2] * 0.5 - boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 + boxes[:, 2:4] = boxes[:, 2:4] * 0.5 + 0.5 nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, nms_top_k, keep_top_k) + nmsed_outs = [-1] if not nmsed_outs else nmsed_outs nmsed_outs = np.array(nmsed_outs).astype('float32') self.op_type = 'multiclass_nms' @@ -188,6 +205,13 @@ class TestMulticlassNMSOp(OpTest): self.check_output() +class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp): + def set_argument(self): + # Here set 2.0 to test the case there is no outputs. + # In practical use, 0.0 < score_threshold < 1.0 + self.score_threshold = 2.0 + + class TestIOU(unittest.TestCase): def test_iou(self): box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32') From f3415ec55e1daf437080d5ee2febb18b6bcb3a09 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 1 Feb 2018 21:53:16 +0800 Subject: [PATCH 5/6] Follow comments. 
--- paddle/operators/bipartite_match_op.cc | 18 ++- paddle/operators/multiclass_nms_op.cc | 104 ++++++++++-------- .../v2/fluid/tests/test_bipartite_match_op.py | 4 +- .../v2/fluid/tests/test_multiclass_nms_op.py | 2 +- 4 files changed, 72 insertions(+), 56 deletions(-) diff --git a/paddle/operators/bipartite_match_op.cc b/paddle/operators/bipartite_match_op.cc index 83c8778fe4..1e6fa2091d 100644 --- a/paddle/operators/bipartite_match_op.cc +++ b/paddle/operators/bipartite_match_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -28,12 +28,18 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("DistMat"), "Input(DistMat) of BipartiteMatch should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ColToRowMatchIndices"), + "Output(ColToRowMatchIndices) of BipartiteMatch should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ColToRowMatchDist"), + "Output(ColToRowMatchDist) of BipartiteMatch should not be null."); auto dims = ctx->GetInputDim("DistMat"); PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2."); ctx->SetOutputDim("ColToRowMatchIndices", dims); - ctx->SetOutputDim("ColToRowMatchDis", dims); + ctx->SetOutputDim("ColToRowMatchDist", dims); } }; @@ -91,7 +97,7 @@ class BipartiteMatchKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* dist_mat = context.Input("DistMat"); auto* match_indices = context.Output("ColToRowMatchIndices"); - auto* match_dist = context.Output("ColToRowMatchDis"); + auto* match_dist = context.Output("ColToRowMatchDist"); auto& dev_ctx = context.device_context(); @@ -148,13 +154,13 @@ class 
BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker { "Otherwise, it means B[j] is matched to row " "ColToRowMatchIndices[i][j] in i-th instance. The row number of " "i-th instance is saved in ColToRowMatchIndices[i][j]."); - AddOutput("ColToRowMatchDis", + AddOutput("ColToRowMatchDist", "(Tensor) A 2-D Tensor with shape [N, M] in float type. " "N is batch size. If ColToRowMatchIndices[i][j] is -1, " - "ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " + "ColToRowMatchDist[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchIndices[i][j] = d, and the row offsets of each " "instance are called LoD. Then " - "ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]"); + "ColToRowMatchDist[i][j] = DistMat[d+LoD[i]][j]"); AddComment(R"DOC( This operator is a greedy bipartite matching algorithm, which is used to obtain the matching with the maximum distance based on the input diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index 4689306d24..cb38e9fa20 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -24,25 +24,33 @@ using LoDTensor = framework::LoDTensor; constexpr int64_t kOutputDim = 6; constexpr int64_t kBBoxSize = 4; -class MulticlassNMSOp : public framework::OperatorWithKernel { +class MultiClassNMSOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Bboxes"), - "Input(Bboxes) of MulticlassNMS should not be null."); + PADDLE_ENFORCE(ctx->HasInput("BBoxes"), + "Input(BBoxes) of MultiClassNMS should not be null."); PADDLE_ENFORCE(ctx->HasInput("Scores"), - "Input(Scores) of MulticlassNMS should not be null."); + "Input(Scores) of MultiClassNMS should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MultiClassNMS should not be null."); - auto box_dims = ctx->GetInputDim("Bboxes"); + auto box_dims = ctx->GetInputDim("BBoxes"); auto score_dims = ctx->GetInputDim("Scores"); PADDLE_ENFORCE_EQ(box_dims.size(), 2, - "The rank of Input(Bboxes) must be 3."); + "The rank of Input(BBoxes) must be 2."); PADDLE_ENFORCE_EQ(score_dims.size(), 3, "The rank of Input(Scores) must be 3."); - PADDLE_ENFORCE_EQ(box_dims[1], 4); - PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); + PADDLE_ENFORCE_EQ(box_dims[1], 4, + "The 2nd dimension of Input(BBoxes) must be 4, " + "represents the layout of coordinate " + "[xmin, ymin, xmax, ymax]"); + PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2], + "The 1st dimensiong of Input(BBoxes) must be equal to " + "3rd dimension of Input(Scores), which represents the " + "predicted bboxes."); // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. @@ -86,15 +94,16 @@ static inline void GetMaxScoreIndex( template T BBoxArea(const T* box, const bool normalized) { if (box[2] < box[0] || box[3] < box[1]) { - // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. 
- return T(0.); + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); } else { const T w = box[2] - box[0]; const T h = box[3] - box[1]; if (normalized) { return w * h; } else { - // If bbox is not within range [0, 1]. + // If coordinate values are not within range [0, 1]. return (w + 1) * (h + 1); } } @@ -121,7 +130,7 @@ static inline T JaccardOverlap(const T* box1, const T* box2, } template -class MulticlassNMSKernel : public framework::OpKernel { +class MultiClassNMSKernel : public framework::OpKernel { public: void NMSFast(const Tensor& bbox, const Tensor& scores, const T score_threshold, const T nms_threshold, const T eta, @@ -163,10 +172,10 @@ class MulticlassNMSKernel : public framework::OpKernel { } } - void MulticlassNMS(const framework::ExecutionContext& ctx, + void MultiClassNMS(const framework::ExecutionContext& ctx, const Tensor& scores, const Tensor& bboxes, - std::map>* indices, - int* num_nmsed_out) const { + std::map>& indices, + int& num_nmsed_out) const { int64_t background_label = ctx.Attr("background_label"); int64_t nms_top_k = ctx.Attr("nms_top_k"); int64_t keep_top_k = ctx.Attr("keep_top_k"); @@ -181,15 +190,15 @@ class MulticlassNMSKernel : public framework::OpKernel { if (c == background_label) continue; Tensor score = scores.Slice(c, c + 1); NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, - &((*indices)[c])); - num_det += (*indices)[c].size(); + &(indices[c])); + num_det += indices[c].size(); } - *num_nmsed_out = num_det; + num_nmsed_out = num_det; const T* scores_data = scores.data(); if (keep_top_k > -1 && num_det > keep_top_k) { std::vector>> score_index_pairs; - for (const auto& it : *indices) { + for (const auto& it : indices) { int label = it.first; const T* sdata = scores_data + label * predict_dim; const std::vector& label_indices = it.second; @@ -212,12 +221,12 @@ class MulticlassNMSKernel : public framework::OpKernel { int idx = 
score_index_pairs[j].second.second; new_indices[label].push_back(idx); } - new_indices.swap(*indices); - *num_nmsed_out = keep_top_k; + new_indices.swap(indices); + num_nmsed_out = keep_top_k; } } - void MulticlassOutput(const Tensor& scores, const Tensor& bboxes, + void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, std::map>& selected_indices, Tensor* outs) const { int predict_dim = scores.dims()[1]; @@ -229,23 +238,21 @@ class MulticlassNMSKernel : public framework::OpKernel { for (const auto& it : selected_indices) { int label = it.first; const T* sdata = scores_data + label * predict_dim; - std::vector indices = it.second; + const std::vector& indices = it.second; for (int j = 0; j < indices.size(); ++j) { int idx = indices[j]; const T* bdata = bboxes_data + idx * kBBoxSize; odata[count * kOutputDim] = label; // label odata[count * kOutputDim + 1] = sdata[idx]; // score - odata[count * kOutputDim + 2] = bdata[0]; // xmin - odata[count * kOutputDim + 3] = bdata[1]; // ymin - odata[count * kOutputDim + 4] = bdata[2]; // xmax - odata[count * kOutputDim + 5] = bdata[3]; // ymax + // xmin, ymin, xmax, ymax + std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); count++; } } } void Compute(const framework::ExecutionContext& ctx) const override { - auto* boxes = ctx.Input("Bboxes"); + auto* boxes = ctx.Input("BBoxes"); auto* scores = ctx.Input("Scores"); auto* outs = ctx.Output("Out"); @@ -262,7 +269,7 @@ class MulticlassNMSKernel : public framework::OpKernel { ins_score.Resize({class_num, predict_dim}); std::map> indices; int num_nmsed_out = 0; - MulticlassNMS(ctx, ins_score, *boxes, &indices, &num_nmsed_out); + MultiClassNMS(ctx, ins_score, *boxes, indices, num_nmsed_out); all_indices.push_back(indices); batch_starts.push_back(batch_starts.back() + num_nmsed_out); } @@ -280,7 +287,7 @@ class MulticlassNMSKernel : public framework::OpKernel { int64_t e = batch_starts[i + 1]; if (e > s) { Tensor out = outs->Slice(s, e); - 
MulticlassOutput(ins_score, *boxes, all_indices[i], &out); + MultiClassOutput(ins_score, *boxes, all_indices[i], &out); } } } @@ -292,28 +299,31 @@ class MulticlassNMSKernel : public framework::OpKernel { } }; -class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { +class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker { public: - MulticlassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) + MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Bboxes", - "(Tensor) A 2-D Tensor with shape [M, 4] represents the location " - "predictions with M bboxes. 4 is the number of " - "each location coordinates."); + AddInput("BBoxes", + "(Tensor) A 2-D Tensor with shape [M, 4] represents the " + "predicted locations of M bounding bboxes. Each bounding box " + "has four coordinate values and the layout is " + "[xmin, ymin, xmax, ymax]."); AddInput("Scores", "(Tensor) A 3-D Tensor with shape [N, C, M] represents the " - "confidence predictions. N is the batch size, C is the class " - "number, M is number of predictions for each class, which is " - "the same with Bboxes."); + "predicted confidence predictions. N is the batch size, C is the " + "class number, M is number of bounding boxes. For each category " + "there are total M scores which corresponding M bounding boxes. " + " Please note, M is equal to the 1st dimension of BBoxes. "); AddAttr( "background_label", "(int64_t, defalut: 0) " - "The index of background label, the background label will be ignored.") + "The index of background label, the background label will be ignored. " + "If set to -1, then all categories will be considered.") .SetDefault(0); AddAttr("score_threshold", "(float) " - "Only consider detections whose confidences are larger than " - "a threshold. If not provided, consider all boxes."); + "Threshold to filter out bounding boxes with low " + "confidence score. 
If not provided, consider all boxes."); AddAttr("nms_top_k", "(int64_t) " "Maximum number of detections to be kept according to the " @@ -368,8 +378,8 @@ value which is -1. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(multiclass_nms, ops::MulticlassNMSOp, - ops::MulticlassNMSOpMaker, +REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp, + ops::MultiClassNMSOpMaker, paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MulticlassNMSKernel, - ops::MulticlassNMSKernel); +REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel, + ops::MultiClassNMSKernel); diff --git a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py index c35fb20b10..4943bbb338 100644 --- a/python/paddle/v2/fluid/tests/test_bipartite_match_op.py +++ b/python/paddle/v2/fluid/tests/test_bipartite_match_op.py @@ -72,7 +72,7 @@ class TestBipartiteMatchOpWithLoD(OpTest): self.inputs = {'DistMat': (dist, lod)} self.outputs = { 'ColToRowMatchIndices': (match_indices), - 'ColToRowMatchDis': (match_dist), + 'ColToRowMatchDist': (match_dist), } def test_check_output(self): @@ -89,7 +89,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest): self.inputs = {'DistMat': dist} self.outputs = { 'ColToRowMatchIndices': match_indices, - 'ColToRowMatchDis': match_dist, + 'ColToRowMatchDist': match_dist, } def test_check_output(self): diff --git a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py index 3097b8388c..3b80d2359b 100644 --- a/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py +++ b/python/paddle/v2/fluid/tests/test_multiclass_nms_op.py @@ -190,7 +190,7 @@ class TestMulticlassNMSOp(OpTest): nmsed_outs = np.array(nmsed_outs).astype('float32') self.op_type = 'multiclass_nms' - self.inputs = {'Bboxes': boxes, 'Scores': scores} + self.inputs = {'BBoxes': boxes, 'Scores': scores} self.outputs = {'Out': (nmsed_outs, 
[lod])} self.attrs = { 'background_label': 0, From a6f3846d8ff1b9a9d6361381447d1ab7cab7f7ec Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Fri, 2 Feb 2018 16:33:33 +0800 Subject: [PATCH 6/6] Remove the redundant header file and make one function inline. --- paddle/operators/multiclass_nms_op.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddle/operators/multiclass_nms_op.cc b/paddle/operators/multiclass_nms_op.cc index cb38e9fa20..8a65fe69f1 100644 --- a/paddle/operators/multiclass_nms_op.cc +++ b/paddle/operators/multiclass_nms_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -92,7 +91,7 @@ static inline void GetMaxScoreIndex( } template -T BBoxArea(const T* box, const bool normalized) { +static inline T BBoxArea(const T* box, const bool normalized) { if (box[2] < box[0] || box[3] < box[1]) { // If coordinate values are is invalid // (e.g. xmax < xmin or ymax < ymin), return 0.