Follow comments.

emailweixu-patch-1
dangqingqing 8 years ago
parent 5378864088
commit f3415ec55e

@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -28,12 +28,18 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DistMat"), PADDLE_ENFORCE(ctx->HasInput("DistMat"),
"Input(DistMat) of BipartiteMatch should not be null."); "Input(DistMat) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchIndices"),
"Output(ColToRowMatchIndices) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchDist"),
"Output(ColToRowMatchDist) of BipartiteMatch should not be null.");
auto dims = ctx->GetInputDim("DistMat"); auto dims = ctx->GetInputDim("DistMat");
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2."); PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDis", dims); ctx->SetOutputDim("ColToRowMatchDist", dims);
} }
}; };
@ -91,7 +97,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dist_mat = context.Input<LoDTensor>("DistMat"); auto* dist_mat = context.Input<LoDTensor>("DistMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices"); auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis"); auto* match_dist = context.Output<Tensor>("ColToRowMatchDist");
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>(); auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
@ -148,13 +154,13 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise, it means B[j] is matched to row " "Otherwise, it means B[j] is matched to row "
"ColToRowMatchIndices[i][j] in i-th instance. The row number of " "ColToRowMatchIndices[i][j] in i-th instance. The row number of "
"i-th instance is saved in ColToRowMatchIndices[i][j]."); "i-th instance is saved in ColToRowMatchIndices[i][j].");
AddOutput("ColToRowMatchDis", AddOutput("ColToRowMatchDist",
"(Tensor) A 2-D Tensor with shape [N, M] in float type. " "(Tensor) A 2-D Tensor with shape [N, M] in float type. "
"N is batch size. If ColToRowMatchIndices[i][j] is -1, " "N is batch size. If ColToRowMatchIndices[i][j] is -1, "
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed " "ColToRowMatchDist[i][j] is also -1.0. Otherwise, assumed "
"ColToRowMatchIndices[i][j] = d, and the row offsets of each " "ColToRowMatchIndices[i][j] = d, and the row offsets of each "
"instance are called LoD. Then " "instance are called LoD. Then "
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]"); "ColToRowMatchDist[i][j] = DistMat[d+LoD[i]][j]");
AddComment(R"DOC( AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the maximum distance based on the input obtain the matching with the maximum distance based on the input

@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -24,25 +24,33 @@ using LoDTensor = framework::LoDTensor;
constexpr int64_t kOutputDim = 6; constexpr int64_t kOutputDim = 6;
constexpr int64_t kBBoxSize = 4; constexpr int64_t kBBoxSize = 4;
class MulticlassNMSOp : public framework::OperatorWithKernel { class MultiClassNMSOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Bboxes"), PADDLE_ENFORCE(ctx->HasInput("BBoxes"),
"Input(Bboxes) of MulticlassNMS should not be null."); "Input(BBoxes) of MultiClassNMS should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scores"), PADDLE_ENFORCE(ctx->HasInput("Scores"),
"Input(Scores) of MulticlassNMS should not be null."); "Input(Scores) of MultiClassNMS should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MultiClassNMS should not be null.");
auto box_dims = ctx->GetInputDim("Bboxes"); auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores"); auto score_dims = ctx->GetInputDim("Scores");
PADDLE_ENFORCE_EQ(box_dims.size(), 2, PADDLE_ENFORCE_EQ(box_dims.size(), 2,
"The rank of Input(Bboxes) must be 3."); "The rank of Input(BBoxes) must be 2.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3, PADDLE_ENFORCE_EQ(score_dims.size(), 3,
"The rank of Input(Scores) must be 3."); "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[1], 4); PADDLE_ENFORCE_EQ(box_dims[1], 4,
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]); "The 2nd dimension of Input(BBoxes) must be 4, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax]");
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2],
"The 1st dimension of Input(BBoxes) must be equal to "
"3rd dimension of Input(Scores), which represents the "
"predicted bboxes.");
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
@ -86,15 +94,16 @@ static inline void GetMaxScoreIndex(
template <class T> template <class T>
T BBoxArea(const T* box, const bool normalized) { T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) { if (box[2] < box[0] || box[3] < box[1]) {
// If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. // If coordinate values are invalid
return T(0.); // (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else { } else {
const T w = box[2] - box[0]; const T w = box[2] - box[0];
const T h = box[3] - box[1]; const T h = box[3] - box[1];
if (normalized) { if (normalized) {
return w * h; return w * h;
} else { } else {
// If bbox is not within range [0, 1]. // If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1); return (w + 1) * (h + 1);
} }
} }
@ -121,7 +130,7 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
} }
template <typename T> template <typename T>
class MulticlassNMSKernel : public framework::OpKernel<T> { class MultiClassNMSKernel : public framework::OpKernel<T> {
public: public:
void NMSFast(const Tensor& bbox, const Tensor& scores, void NMSFast(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T nms_threshold, const T eta, const T score_threshold, const T nms_threshold, const T eta,
@ -163,10 +172,10 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
} }
} }
void MulticlassNMS(const framework::ExecutionContext& ctx, void MultiClassNMS(const framework::ExecutionContext& ctx,
const Tensor& scores, const Tensor& bboxes, const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, std::map<int, std::vector<int>>& indices,
int* num_nmsed_out) const { int& num_nmsed_out) const {
int64_t background_label = ctx.Attr<int>("background_label"); int64_t background_label = ctx.Attr<int>("background_label");
int64_t nms_top_k = ctx.Attr<int>("nms_top_k"); int64_t nms_top_k = ctx.Attr<int>("nms_top_k");
int64_t keep_top_k = ctx.Attr<int>("keep_top_k"); int64_t keep_top_k = ctx.Attr<int>("keep_top_k");
@ -181,15 +190,15 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
if (c == background_label) continue; if (c == background_label) continue;
Tensor score = scores.Slice(c, c + 1); Tensor score = scores.Slice(c, c + 1);
NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k, NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, nms_top_k,
&((*indices)[c])); &(indices[c]));
num_det += (*indices)[c].size(); num_det += indices[c].size();
} }
*num_nmsed_out = num_det; num_nmsed_out = num_det;
const T* scores_data = scores.data<T>(); const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) { if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs; std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) { for (const auto& it : indices) {
int label = it.first; int label = it.first;
const T* sdata = scores_data + label * predict_dim; const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& label_indices = it.second; const std::vector<int>& label_indices = it.second;
@ -212,12 +221,12 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int idx = score_index_pairs[j].second.second; int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx); new_indices[label].push_back(idx);
} }
new_indices.swap(*indices); new_indices.swap(indices);
*num_nmsed_out = keep_top_k; num_nmsed_out = keep_top_k;
} }
} }
void MulticlassOutput(const Tensor& scores, const Tensor& bboxes, void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>& selected_indices, std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) const { Tensor* outs) const {
int predict_dim = scores.dims()[1]; int predict_dim = scores.dims()[1];
@ -229,23 +238,21 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
for (const auto& it : selected_indices) { for (const auto& it : selected_indices) {
int label = it.first; int label = it.first;
const T* sdata = scores_data + label * predict_dim; const T* sdata = scores_data + label * predict_dim;
std::vector<int> indices = it.second; const std::vector<int>& indices = it.second;
for (int j = 0; j < indices.size(); ++j) { for (int j = 0; j < indices.size(); ++j) {
int idx = indices[j]; int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize; const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score odata[count * kOutputDim + 1] = sdata[idx]; // score
odata[count * kOutputDim + 2] = bdata[0]; // xmin // xmin, ymin, xmax, ymax
odata[count * kOutputDim + 3] = bdata[1]; // ymin std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
odata[count * kOutputDim + 4] = bdata[2]; // xmax
odata[count * kOutputDim + 5] = bdata[3]; // ymax
count++; count++;
} }
} }
} }
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* boxes = ctx.Input<Tensor>("Bboxes"); auto* boxes = ctx.Input<Tensor>("BBoxes");
auto* scores = ctx.Input<Tensor>("Scores"); auto* scores = ctx.Input<Tensor>("Scores");
auto* outs = ctx.Output<LoDTensor>("Out"); auto* outs = ctx.Output<LoDTensor>("Out");
@ -262,7 +269,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
ins_score.Resize({class_num, predict_dim}); ins_score.Resize({class_num, predict_dim});
std::map<int, std::vector<int>> indices; std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0; int num_nmsed_out = 0;
MulticlassNMS(ctx, ins_score, *boxes, &indices, &num_nmsed_out); MultiClassNMS(ctx, ins_score, *boxes, indices, num_nmsed_out);
all_indices.push_back(indices); all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out); batch_starts.push_back(batch_starts.back() + num_nmsed_out);
} }
@ -280,7 +287,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int64_t e = batch_starts[i + 1]; int64_t e = batch_starts[i + 1];
if (e > s) { if (e > s) {
Tensor out = outs->Slice(s, e); Tensor out = outs->Slice(s, e);
MulticlassOutput(ins_score, *boxes, all_indices[i], &out); MultiClassOutput(ins_score, *boxes, all_indices[i], &out);
} }
} }
} }
@ -292,28 +299,31 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
} }
}; };
class MulticlassNMSOpMaker : public framework::OpProtoAndCheckerMaker { class MultiClassNMSOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
MulticlassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker) MultiClassNMSOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Bboxes", AddInput("BBoxes",
"(Tensor) A 2-D Tensor with shape [M, 4] represents the location " "(Tensor) A 2-D Tensor with shape [M, 4] represents the "
"predictions with M bboxes. 4 is the number of " "predicted locations of M bounding boxes. Each bounding box "
"each location coordinates."); "has four coordinate values and the layout is "
"[xmin, ymin, xmax, ymax].");
AddInput("Scores", AddInput("Scores",
"(Tensor) A 3-D Tensor with shape [N, C, M] represents the " "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
"confidence predictions. N is the batch size, C is the class " "predicted confidence predictions. N is the batch size, C is the "
"number, M is number of predictions for each class, which is " "class number, M is number of bounding boxes. For each category "
"the same with Bboxes."); "there are total M scores which correspond to M bounding boxes. "
" Please note, M is equal to the 1st dimension of BBoxes. ");
AddAttr<int>( AddAttr<int>(
"background_label", "background_label",
"(int64_t, default: 0) " "(int64_t, default: 0) "
"The index of background label, the background label will be ignored.") "The index of background label, the background label will be ignored. "
"If set to -1, then all categories will be considered.")
.SetDefault(0); .SetDefault(0);
AddAttr<float>("score_threshold", AddAttr<float>("score_threshold",
"(float) " "(float) "
"Only consider detections whose confidences are larger than " "Threshold to filter out bounding boxes with low "
"a threshold. If not provided, consider all boxes."); "confidence score. If not provided, consider all boxes.");
AddAttr<int>("nms_top_k", AddAttr<int>("nms_top_k",
"(int64_t) " "(int64_t) "
"Maximum number of detections to be kept according to the " "Maximum number of detections to be kept according to the "
@ -368,8 +378,8 @@ value which is -1.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(multiclass_nms, ops::MulticlassNMSOp, REGISTER_OPERATOR(multiclass_nms, ops::MultiClassNMSOp,
ops::MulticlassNMSOpMaker, ops::MultiClassNMSOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MulticlassNMSKernel<float>, REGISTER_OP_CPU_KERNEL(multiclass_nms, ops::MultiClassNMSKernel<float>,
ops::MulticlassNMSKernel<double>); ops::MultiClassNMSKernel<double>);

@ -72,7 +72,7 @@ class TestBipartiteMatchOpWithLoD(OpTest):
self.inputs = {'DistMat': (dist, lod)} self.inputs = {'DistMat': (dist, lod)}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': (match_indices), 'ColToRowMatchIndices': (match_indices),
'ColToRowMatchDis': (match_dist), 'ColToRowMatchDist': (match_dist),
} }
def test_check_output(self): def test_check_output(self):
@ -89,7 +89,7 @@ class TestBipartiteMatchOpWithoutLoD(OpTest):
self.inputs = {'DistMat': dist} self.inputs = {'DistMat': dist}
self.outputs = { self.outputs = {
'ColToRowMatchIndices': match_indices, 'ColToRowMatchIndices': match_indices,
'ColToRowMatchDis': match_dist, 'ColToRowMatchDist': match_dist,
} }
def test_check_output(self): def test_check_output(self):

@ -190,7 +190,7 @@ class TestMulticlassNMSOp(OpTest):
nmsed_outs = np.array(nmsed_outs).astype('float32') nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms' self.op_type = 'multiclass_nms'
self.inputs = {'Bboxes': boxes, 'Scores': scores} self.inputs = {'BBoxes': boxes, 'Scores': scores}
self.outputs = {'Out': (nmsed_outs, [lod])} self.outputs = {'Out': (nmsed_outs, [lod])}
self.attrs = { self.attrs = {
'background_label': 0, 'background_label': 0,

Loading…
Cancel
Save