From c5a14ed4cdbaebb68490a28a914a094b62c35bcc Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 19 Jan 2018 11:31:58 +0800 Subject: [PATCH 1/4] add mine_hard_examples operator --- paddle/operators/mine_hard_examples_op.cc | 184 ++++++++++++++++++ paddle/operators/mine_hard_examples_op.h | 148 ++++++++++++++ .../fluid/tests/test_mine_hard_examples_op.py | 99 ++++++++++ 3 files changed, 431 insertions(+) create mode 100644 paddle/operators/mine_hard_examples_op.cc create mode 100755 paddle/operators/mine_hard_examples_op.h create mode 100755 python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc new file mode 100644 index 0000000000..75098d0bcd --- /dev/null +++ b/paddle/operators/mine_hard_examples_op.cc @@ -0,0 +1,184 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/mine_hard_examples_op.h" + +namespace paddle { +namespace operators { + +class MineHardExamplesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("ClsLoss"), + "Input(ClsLoss) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("MatchIndics"), + "Input(MatchIndics) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MatchDis"), + "Input(MatchDis) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("NegIndics"), + "Output(NegIndics) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("UpdatedMatchIndics"), + "Output(UpdatedMatchIndics) of MineHardExamplesOp should not be null."); + + auto cls_loss_dims = ctx->GetInputDim("ClsLoss"); + auto idx_dims = ctx->GetInputDim("MatchIndics"); + auto dis_dims = ctx->GetInputDim("MatchDis"); + + PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL, + "The shape of ClsLoss is [N, Np]."); + PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL, + "The shape of MatchIndics is [N, Np]."); + PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL, + "The shape of MatchDis is [N, Np]."); + + if (ctx->HasInput("LocLoss")) { + auto loc_loss_dims = ctx->GetInputDim("LocLoss"); + PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL, + "The shape of LocLoss is [N, Np]."); + PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0], + "Batch size of ClsLoss and LocLoss must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], loc_loss_dims[1], + "Prior box number of ClsLoss and LocLoss must be the same."); + } + + PADDLE_ENFORCE_EQ( + cls_loss_dims[0], idx_dims[0], + "Batch size of ClsLoss and MatchIndics must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], idx_dims[1], + "Prior box number of ClsLoss and MatchIndics must be the same."); + + PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0], + "Batch size of ClsLoss and MatchDis must be the same."); + PADDLE_ENFORCE_EQ( + cls_loss_dims[1], 
idx_dims[1], + "Prior box number of ClsLoss and MatchDis must be the same."); + + auto mining_type = + GetMiningType(ctx->Attrs().Get("mining_type")); + + PADDLE_ENFORCE_NE(mining_type, MiningType::kNone, + "mining_type must be hard_example or max_negative"); + + if (mining_type == MiningType::kMaxNegative) { + auto neg_pos_ratio = ctx->Attrs().Get("neg_pos_ratio"); + auto neg_dis_threshold = ctx->Attrs().Get("neg_dis_threshold"); + PADDLE_ENFORCE_GT( + neg_pos_ratio, 0.0f, + "neg_pos_ratio must greater than zero in max_negative mode"); + PADDLE_ENFORCE_GT( + neg_dis_threshold, 0.0f, + "neg_dis_threshold must greater than zero in max_negative mode"); + } else if (mining_type == MiningType::kHardExample) { + auto sample_size = ctx->Attrs().Get("sample_size"); + PADDLE_ENFORCE_GT( + sample_size, 0, + "sample_size must greater than zero in hard_example mode"); + } + + ctx->SetOutputDim("UpdatedMatchIndics", idx_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("ClsLoss")->type()), + ctx.device_context()); + } +}; + +class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MineHardExamplesOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "ClsLoss", + "(Tensor, default Tensor), The classification loss wit shape " + "[N, Np], N is the batch size and Np is the number of prior box."); + AddInput("LocLoss", + "(Tensor, optional, default Tensor), The localization loss " + "wit shape [N, Np], N is the batch size and Np is the number of " + "prior box.") + .AsDispensable(); + AddInput("MatchIndics", + "(Tensor, Tensor), Matched indices with shape [N, Np], N is " + "the batch size and Np is the number of prior box. " + "MatchIndics[i][j] equal -1 means box[j] does not match any " + "entity, otherwise means Box[j] is matched to row."); + AddInput("MatchDis", + "(Tensor, default Tensor) Matched indices with shape [N, " + "Np], N is the batch size and Np is the number of prior box."); + AddAttr("neg_pos_ratio", + "(float) The ratio of the negative box to the positive " + "box. Use only when mining_type is equal to max_negative.") + .SetDefault(1.0); + AddAttr("neg_dis_threshold", + "(float) The negative box dis value threshold. " + "Use only when mining_type is equal to max_negative.") + .SetDefault(0.5); + AddAttr("sample_size", + "(float) The max sample size of negative box. Use only when " + "mining_type is equal to hard_example.") + .SetDefault(0); + AddAttr("mining_type", + "(float) The mining algorithm name, the value is " + "hard_example or max_negative.") + .SetDefault("max_negative") + .InEnum({"hard_example", "max_negative"}); + + AddOutput("NegIndics", + "(LoDTensor) The output of negative example indics.a lod tensor " + "with shape [Neg, 1]. The size of lod[0] is batch size, " + "and each element is the box index. " + "For example, the batch size is 2, the lod is [[0, 1, 2]], " + "the sample 0's box 1(MatchIndics[0][1]) is selected, " + "and sample 1's box 0 is selected. The output NegIndics is " + "[[1], [0]]."); + + AddOutput("UpdatedMatchIndics", + "(Tensor) The output of updated MatchIndics, a tensor with " + "shape [N, M]. Only update when mining_type is equal to " + "hard_example. The input MatchIndics elements will be update to " + "-1 when it not in the highest loss list"); + + AddComment(R"DOC( +Mine hard examples Operator. 
+This operator implements hard example mining to select a subset of negative box indics. +For each image, selects the box with highest losses. subject to the condition that the box cannot have +an MatchDis > neg_dis_threshold when mining_type is equals max_negative. The selected number is +min(sample_size, max_negative_box_number) when mining_type is equals hard_example, +or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is +equals max_negative, where the max_negative_box_number is the count of MatchIndics elements with value -1. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(mine_hard_examples, ops::MineHardExamplesOp, + ops::MineHardExamplesOpMaker); + +REGISTER_OP_CPU_KERNEL( + mine_hard_examples, + ops::MineHardExamplesKernel, + ops::MineHardExamplesKernel); diff --git a/paddle/operators/mine_hard_examples_op.h b/paddle/operators/mine_hard_examples_op.h new file mode 100755 index 0000000000..0a652a60c5 --- /dev/null +++ b/paddle/operators/mine_hard_examples_op.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +enum MiningType { kNone = 0, kMaxNegative, kHardExample }; + +template +bool SortScoreDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, + const float match_dis, + const float neg_dis_threshold) { + if (mining_type == MiningType::kMaxNegative) { + return match_idx == -1 && match_dis < neg_dis_threshold; + } else if (mining_type == MiningType::kHardExample) { + return true; + } else { + return false; + } +} + +MiningType GetMiningType(std::string str) { + if (str == "max_negative") { + return MiningType::kMaxNegative; + } else if (str == "hard_example") { + return MiningType::kHardExample; + } else { + return MiningType::kNone; + } +} + +template +class MineHardExamplesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_cls_loss = ctx.Input("ClsLoss"); + auto* in_loc_loss = ctx.Input("LocLoss"); + auto* in_matched_indics = ctx.Input("MatchIndics"); + auto* in_match_dis = ctx.Input("MatchDis"); + float neg_pos_ratio = ctx.Attr("neg_pos_ratio"); + T neg_dis_threshold = static_cast(ctx.Attr("neg_dis_threshold")); + int sample_size = ctx.Attr("sample_size"); + MiningType mining_type = + GetMiningType(ctx.Attr("mining_type")); + + auto out_neg_indics = ctx.Output("NegIndics"); + auto out_match_indics = ctx.Output("UpdatedMatchIndics"); + + framework::Copy(*in_matched_indics, ctx.GetPlace(), out_match_indics); + + int batch_size = in_matched_indics->dims()[0]; + int prior_num = in_matched_indics->dims()[1]; + + auto match_indices = framework::EigenMatrix::From(*in_matched_indics); + + auto 
match_indices_et = + framework::EigenMatrix::From(*out_match_indics); + + auto match_dis = framework::EigenMatrix::From(*in_match_dis); + auto cls_loss = framework::EigenMatrix::From(*in_cls_loss); + auto loc_loss = framework::EigenMatrix::From(*in_loc_loss); + + std::vector> all_neg_indices; + int all_neg_num = 0; + for (int n = 0; n < batch_size; ++n) { + std::vector> loss_idx; + int neg_sel = 0; + for (int m = 0; m < prior_num; ++m) { + if (IsEligibleMining(mining_type, match_indices(n, m), match_dis(n, m), + neg_dis_threshold)) { + T loss = cls_loss(n, m); + if (mining_type == MiningType::kHardExample) { + loss = cls_loss(n, m) + loc_loss(n, m); + } + loss_idx.push_back(std::make_pair(loss, m)); + ++neg_sel; + } + } + if (mining_type == MiningType::kMaxNegative) { + int num_pos = 0; + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) != -1) ++num_pos; + } + neg_sel = std::min(static_cast(num_pos * neg_pos_ratio), neg_sel); + } else if (mining_type == MiningType::kHardExample) { + neg_sel = std::min(sample_size, neg_sel); + } + std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); + std::set sel_indices; + std::vector neg_indices; + for (int n = 0; n < neg_sel; ++n) { + sel_indices.insert(loss_idx[n].second); + } + + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) > -1) { + if (mining_type == MiningType::kHardExample && + sel_indices.find(m) == sel_indices.end()) { + match_indices_et(n, m) = -1; + } + } else { + if (sel_indices.find(m) != sel_indices.end()) { + neg_indices.push_back(m); + } + } + } + all_neg_indices.push_back(neg_indices); + all_neg_num += neg_indices.size(); + } + + framework::LoD out_neg_indics_lod; + out_neg_indics_lod.resize(1); + int neg_offset = 0; + auto neg_data = out_neg_indics->mutable_data( + framework::make_ddim({all_neg_num, 1}), ctx.GetPlace()); + out_neg_indics_lod[0].push_back(neg_offset); + for (auto neg_indices : all_neg_indices) { + for (auto neg_idx : neg_indices) { + neg_data[neg_offset++] = neg_idx; + } + out_neg_indics_lod[0].push_back(neg_offset); + } + out_neg_indics->set_lod(out_neg_indics_lod); + return; + } +}; +} // namespace operators + +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py new file mode 100755 index 0000000000..e7dd04740a --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py @@ -0,0 +1,99 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +class TestMineHardExamplesOp(OpTest): + def set_data(self): + self.init_test_data() + self.inputs = { + 'ClsLoss': self.cls_loss, + 'LocLoss': self.loc_loss, + 'MatchIndics': self.match_indices, + 'MatchDis': self.match_dis + } + + self.attrs = { + 'neg_pos_ratio': self.neg_pos_ratio, + 'neg_overlap': self.neg_overlap, + 'sample_size': self.sample_size, + 'mining_type': self.mining_type + } + + self.outputs = { + 'NegIndics': (self.neg_indices, self.neg_indices_lod), + 'UpdatedMatchIndics': self.updated_match_indices + } + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + return + + def setUp(self): + self.op_type = "mine_hard_examples" + self.set_data() + + def init_test_data(self): + self.neg_pos_ratio = 1.0 + self.neg_overlap = 0.5 + self.sample_size = 0 + self.mining_type = "max_negative" + self.cls_loss = np.array([[0.1, 0.1, 0.3], + [0.3, 0.1, 0.1]]).astype('float32') + + self.loc_loss = np.array([[0.1, 0.2, 0.3], + [0.3, 0.4, 0.1]]).astype('float32') + + self.match_dis = np.array([[0.2, 0.4, 0.8], + [0.1, 0.9, 0.3]]).astype('float32') + + self.match_indices = np.array([[0, -1, -1], + [-1, 0, -1]]).astype('int32') + + self.updated_match_indices = self.match_indices + + self.neg_indices_lod = [[0, 1, 2]] + self.neg_indices = np.array([[1], [0]]).astype('int32') + + +class TestMineHardExamplesOpHardExample(TestMineHardExamplesOp): + def init_test_data(self): + super(TestMineHardExamplesOpHardExample, self).init_test_data() + self.mining_type = "hard_example" + self.sample_size = 2 + + self.cls_loss = np.array([[0.5, 0.1, 0.3], + [0.3, 0.1, 0.1]]).astype('float32') + + self.loc_loss = np.array([[0.2, 0.2, 0.3], + [0.3, 0.1, 0.2]]).astype('float32') + + self.match_indices = np.array([[0, -1, -1], + [-1, 0, -1]]).astype('int32') + + self.updated_match_indices = np.array([[0, -1, -1], + [-1, -1, -1]]).astype('int32') + + self.neg_indices_lod = [[0, 1, 3]] + self.neg_indices = np.array([[2], [0], [2]]).astype('int32') + + +if __name__ == '__main__': + unittest.main() From ff5570c12605b5b983cb9308ae507e1bb214143c Mon Sep 17 00:00:00 2001 From: wanghaox Date: Thu, 1 Feb 2018 11:29:26 +0800 Subject: [PATCH 2/4] update mine_hard_examples_op --- paddle/operators/mine_hard_examples_op.cc | 234 ++++++++++++++---- paddle/operators/mine_hard_examples_op.h | 148 ----------- .../fluid/tests/test_mine_hard_examples_op.py | 29 +-- 3 files changed, 202 insertions(+), 209 deletions(-) delete mode 100755 paddle/operators/mine_hard_examples_op.h diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc index 75098d0bcd..603368f93c 100644 --- a/paddle/operators/mine_hard_examples_op.cc +++ b/paddle/operators/mine_hard_examples_op.cc @@ -12,41 +12,178 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/operators/mine_hard_examples_op.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace operators { +enum MiningType { kNone = 0, kMaxNegative, kHardExample }; + +template +bool SortScoreDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, + const float match_dist, + const float neg_dist_threshold) { + if (mining_type == MiningType::kMaxNegative) { + return match_idx == -1 && match_dist < neg_dist_threshold; + } else if (mining_type == MiningType::kHardExample) { + return true; + } else { + return false; + } +} + +MiningType GetMiningType(std::string str) { + if (str == "max_negative") { + return MiningType::kMaxNegative; + } else if (str == "hard_example") { + return MiningType::kHardExample; + } else { + return MiningType::kNone; + } +} + +template +class MineHardExamplesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_cls_loss = ctx.Input("ClsLoss"); + auto* in_loc_loss = ctx.Input("LocLoss"); + auto* in_matched_indices = ctx.Input("MatchIndices"); + auto* in_match_dist = ctx.Input("MatchDist"); + float neg_pos_ratio = ctx.Attr("neg_pos_ratio"); + T neg_dist_threshold = + static_cast(ctx.Attr("neg_dist_threshold")); + int sample_size = ctx.Attr("sample_size"); + MiningType mining_type = + GetMiningType(ctx.Attr("mining_type")); + + auto out_neg_indices = ctx.Output("NegIndices"); + auto out_match_indices = + ctx.Output("UpdatedMatchIndices"); + + framework::Copy(*in_matched_indices, ctx.GetPlace(), out_match_indices); + + int batch_size = in_matched_indices->dims()[0]; + int prior_num = in_matched_indices->dims()[1]; + + auto match_indices = framework::EigenMatrix::From(*in_matched_indices); + + auto match_indices_et = + framework::EigenMatrix::From(*out_match_indices); + + auto match_dist = framework::EigenMatrix::From(*in_match_dist); + + const T* cls_loss = in_cls_loss->data(); + const T* loc_loss = nullptr; + if (in_loc_loss) { + loc_loss = in_loc_loss->data(); + } + + std::vector> all_neg_indices; + std::vector batch_starts = {0}; + for (int n = 0; n < batch_size; ++n) { + std::vector> loss_idx; + int neg_sel = 0; + for (int m = 0; m < prior_num; ++m) { + if (IsEligibleMining(mining_type, match_indices(n, m), match_dist(n, m), + neg_dist_threshold)) { + T loss = cls_loss[n * prior_num + m]; + if (mining_type == MiningType::kHardExample && loc_loss != nullptr) { + loss = cls_loss[n * prior_num + m] + loc_loss[n * prior_num + m]; + } + loss_idx.push_back(std::make_pair(loss, m)); + ++neg_sel; + } + } + + if (mining_type == MiningType::kMaxNegative) { + int num_pos = 0; + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) != -1) ++num_pos; + } + neg_sel = std::min(static_cast(num_pos * neg_pos_ratio), neg_sel); + } else if (mining_type == MiningType::kHardExample) { + neg_sel = std::min(sample_size, neg_sel); + } + + std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); + std::set sel_indices; + std::vector neg_indices; + std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, + std::inserter(sel_indices, sel_indices.begin()), + [](std::pair l) -> int { + return static_cast(l.second); + }); + + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) > -1) { + if (mining_type == MiningType::kHardExample && + sel_indices.find(m) == sel_indices.end()) { + 
match_indices_et(n, m) = -1; + } + } else { + if (sel_indices.find(m) != sel_indices.end()) { + neg_indices.push_back(m); + } + } + } + all_neg_indices.push_back(neg_indices); + batch_starts.push_back(batch_starts.back() + neg_indices.size()); + } + + framework::LoD out_neg_indices_lod; + out_neg_indices_lod.emplace_back(batch_starts); + int neg_offset = 0; + auto neg_data = out_neg_indices->mutable_data( + framework::make_ddim({static_cast(batch_starts.back()), 1}), + ctx.GetPlace()); + + for (auto neg_indices : all_neg_indices) { + std::copy(neg_indices.begin(), neg_indices.end(), neg_data + neg_offset); + neg_offset += neg_indices.size(); + } + out_neg_indices->set_lod(out_neg_indices_lod); + return; + } +}; + class MineHardExamplesOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("ClsLoss"), "Input(ClsLoss) of MineHardExamplesOp should not be null."); PADDLE_ENFORCE( - ctx->HasInput("MatchIndics"), - "Input(MatchIndics) of MineHardExamplesOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("MatchDis"), - "Input(MatchDis) of MineHardExamplesOp should not be null."); + ctx->HasInput("MatchIndices"), + "Input(MatchIndices) of MineHardExamplesOp should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("NegIndics"), - "Output(NegIndics) of MineHardExamplesOp should not be null."); + ctx->HasInput("MatchDist"), + "Input(MatchDist) of MineHardExamplesOp should not be null."); PADDLE_ENFORCE( - ctx->HasOutput("UpdatedMatchIndics"), - "Output(UpdatedMatchIndics) of MineHardExamplesOp should not be null."); + ctx->HasOutput("NegIndices"), + "Output(NegIndices) of MineHardExamplesOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("UpdatedMatchIndices"), + "Output(UpdatedMatchIndices) of MineHardExamplesOp should " + "not be null."); auto cls_loss_dims = ctx->GetInputDim("ClsLoss"); - auto idx_dims = ctx->GetInputDim("MatchIndics"); - auto dis_dims = ctx->GetInputDim("MatchDis"); + auto idx_dims = ctx->GetInputDim("MatchIndices"); + auto dis_dims = ctx->GetInputDim("MatchDist"); PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL, "The shape of ClsLoss is [N, Np]."); PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL, - "The shape of MatchIndics is [N, Np]."); + "The shape of MatchIndices is [N, Np]."); PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL, - "The shape of MatchDis is [N, Np]."); + "The shape of MatchDist is [N, Np]."); if (ctx->HasInput("LocLoss")) { auto loc_loss_dims = ctx->GetInputDim("LocLoss"); @@ -61,16 +198,16 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( cls_loss_dims[0], idx_dims[0], - "Batch size of ClsLoss and MatchIndics must be the same."); + "Batch size of ClsLoss and MatchIndices must be the same."); PADDLE_ENFORCE_EQ( cls_loss_dims[1], idx_dims[1], - "Prior box number of ClsLoss and MatchIndics must be the same."); + "Prior box number of ClsLoss and MatchIndices must be the same."); PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0], - "Batch size of ClsLoss and MatchDis must be the same."); + "Batch size of ClsLoss and MatchDist must be the same."); PADDLE_ENFORCE_EQ( cls_loss_dims[1], idx_dims[1], - "Prior box number of ClsLoss and MatchDis must be the same."); + "Prior box number of ClsLoss and MatchDist must be the same."); auto mining_type = GetMiningType(ctx->Attrs().Get("mining_type")); @@ -80,13 +217,13 @@ 
class MineHardExamplesOp : public framework::OperatorWithKernel { if (mining_type == MiningType::kMaxNegative) { auto neg_pos_ratio = ctx->Attrs().Get("neg_pos_ratio"); - auto neg_dis_threshold = ctx->Attrs().Get("neg_dis_threshold"); + auto neg_dist_threshold = ctx->Attrs().Get("neg_dist_threshold"); PADDLE_ENFORCE_GT( neg_pos_ratio, 0.0f, "neg_pos_ratio must greater than zero in max_negative mode"); PADDLE_ENFORCE_GT( - neg_dis_threshold, 0.0f, - "neg_dis_threshold must greater than zero in max_negative mode"); + neg_dist_threshold, 0.0f, + "neg_dist_threshold must greater than zero in max_negative mode"); } else if (mining_type == MiningType::kHardExample) { auto sample_size = ctx->Attrs().Get("sample_size"); PADDLE_ENFORCE_GT( @@ -94,12 +231,12 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { "sample_size must greater than zero in hard_example mode"); } - ctx->SetOutputDim("UpdatedMatchIndics", idx_dims); + ctx->SetOutputDim("UpdatedMatchIndices", idx_dims); } protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { + const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( framework::ToDataType(ctx.Input("ClsLoss")->type()), ctx.device_context()); @@ -108,30 +245,31 @@ class MineHardExamplesOp : public framework::OperatorWithKernel { class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { public: - MineHardExamplesOpMaker(OpProto *proto, OpAttrChecker *op_checker) + MineHardExamplesOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "ClsLoss", - "(Tensor, default Tensor), The classification loss wit shape " + "(Tensor, default Tensor), The classification loss with shape " "[N, Np], N is the batch size and Np is the number of prior box."); AddInput("LocLoss", "(Tensor, optional, default Tensor), The localization loss " "wit shape [N, Np], N is the batch size and Np is the number of " "prior box.") .AsDispensable(); - AddInput("MatchIndics", + AddInput("MatchIndices", "(Tensor, Tensor), Matched indices with shape [N, Np], N is " "the batch size and Np is the number of prior box. " - "MatchIndics[i][j] equal -1 means box[j] does not match any " - "entity, otherwise means Box[j] is matched to row."); - AddInput("MatchDis", + "MatchIndices[i][j] equal -1 means the j-th prior box in i-th " + "instance does not match any entity, otherwise means it is " + "matched to row."); + AddInput("MatchDist", "(Tensor, default Tensor) Matched indices with shape [N, " "Np], N is the batch size and Np is the number of prior box."); AddAttr("neg_pos_ratio", "(float) The ratio of the negative box to the positive " "box. Use only when mining_type is equal to max_negative.") .SetDefault(1.0); - AddAttr("neg_dis_threshold", + AddAttr("neg_dist_threshold", "(float) The negative box dis value threshold. " "Use only when mining_type is equal to max_negative.") .SetDefault(0.5); @@ -145,29 +283,31 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault("max_negative") .InEnum({"hard_example", "max_negative"}); - AddOutput("NegIndics", - "(LoDTensor) The output of negative example indics.a lod tensor " - "with shape [Neg, 1]. The size of lod[0] is batch size, " - "and each element is the box index. " - "For example, the batch size is 2, the lod is [[0, 1, 2]], " - "the sample 0's box 1(MatchIndics[0][1]) is selected, " - "and sample 1's box 0 is selected. 
The output NegIndics is " - "[[1], [0]]."); - - AddOutput("UpdatedMatchIndics", - "(Tensor) The output of updated MatchIndics, a tensor with " - "shape [N, M]. Only update when mining_type is equal to " - "hard_example. The input MatchIndics elements will be update to " - "-1 when it not in the highest loss list"); + AddOutput( + "NegIndices", + "(LoDTensor) The output of negative example indices. a LoDTensor " + "with shape [Neg, 1]. The size of lod[0] minus 1 is batch size, " + "and each element is the prior box index. " + "For example, the batch size is 2, the lod is [[0, 1, 2]], " + "the sample 0's box 1(MatchIndices[0][1]) is selected, " + "and sample 1's box 0 is selected. The output NegIndices is " + "[[1], [0]]."); + + AddOutput("UpdatedMatchIndices", + "(Tensor) The output of updated MatchIndices, a tensor with " + "shape [N, Np]. Only update when mining_type is equal to " + "hard_example. The input MatchIndices elements will be update to " + "-1 when it is not in the candidate high loss list of negative " + "examples."); AddComment(R"DOC( Mine hard examples Operator. -This operator implements hard example mining to select a subset of negative box indics. +This operator implements hard example mining to select a subset of negative box indices. For each image, selects the box with highest losses. subject to the condition that the box cannot have -an MatchDis > neg_dis_threshold when mining_type is equals max_negative. The selected number is +an Matcht > neg_dist_threshold when mining_type is equals max_negative. The selected number is min(sample_size, max_negative_box_number) when mining_type is equals hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is -equals max_negative, where the max_negative_box_number is the count of MatchIndics elements with value -1. +equals max_negative, where the max_negative_box_number is the count of MatchIndices elements with value -1. )DOC"); } }; diff --git a/paddle/operators/mine_hard_examples_op.h b/paddle/operators/mine_hard_examples_op.h deleted file mode 100755 index 0a652a60c5..0000000000 --- a/paddle/operators/mine_hard_examples_op.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -enum MiningType { kNone = 0, kMaxNegative, kHardExample }; - -template -bool SortScoreDescend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; -} - -inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, - const float match_dis, - const float neg_dis_threshold) { - if (mining_type == MiningType::kMaxNegative) { - return match_idx == -1 && match_dis < neg_dis_threshold; - } else if (mining_type == MiningType::kHardExample) { - return true; - } else { - return false; - } -} - -MiningType GetMiningType(std::string str) { - if (str == "max_negative") { - return MiningType::kMaxNegative; - } else if (str == "hard_example") { - return MiningType::kHardExample; - } else { - return MiningType::kNone; - } -} - -template -class MineHardExamplesKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_cls_loss = ctx.Input("ClsLoss"); - auto* in_loc_loss = ctx.Input("LocLoss"); - auto* in_matched_indics = ctx.Input("MatchIndics"); - auto* in_match_dis = ctx.Input("MatchDis"); - float neg_pos_ratio = ctx.Attr("neg_pos_ratio"); - T neg_dis_threshold = static_cast(ctx.Attr("neg_dis_threshold")); - int sample_size = ctx.Attr("sample_size"); - MiningType mining_type = - GetMiningType(ctx.Attr("mining_type")); - - auto out_neg_indics = ctx.Output("NegIndics"); - auto out_match_indics = ctx.Output("UpdatedMatchIndics"); - - framework::Copy(*in_matched_indics, ctx.GetPlace(), out_match_indics); - - int batch_size = in_matched_indics->dims()[0]; - int prior_num = in_matched_indics->dims()[1]; - - auto match_indices = framework::EigenMatrix::From(*in_matched_indics); - - auto match_indices_et = - framework::EigenMatrix::From(*out_match_indics); - - auto match_dis = framework::EigenMatrix::From(*in_match_dis); - auto cls_loss = framework::EigenMatrix::From(*in_cls_loss); - auto loc_loss = framework::EigenMatrix::From(*in_loc_loss); - - std::vector> all_neg_indices; - int all_neg_num = 0; - for (int n = 0; n < batch_size; ++n) { - std::vector> loss_idx; - int neg_sel = 0; - for (int m = 0; m < prior_num; ++m) { - if (IsEligibleMining(mining_type, match_indices(n, m), match_dis(n, m), - neg_dis_threshold)) { - T loss = cls_loss(n, m); - if (mining_type == MiningType::kHardExample) { - loss = cls_loss(n, m) + loc_loss(n, m); - } - loss_idx.push_back(std::make_pair(loss, m)); - ++neg_sel; - } - } - if (mining_type == MiningType::kMaxNegative) { - int num_pos = 0; - for (int m = 0; m < prior_num; ++m) { - if (match_indices(n, m) != -1) ++num_pos; - } - neg_sel = std::min(static_cast(num_pos * neg_pos_ratio), neg_sel); - } else if (mining_type == MiningType::kHardExample) { - neg_sel = std::min(sample_size, neg_sel); - } - std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); - std::set sel_indices; - std::vector neg_indices; - for (int n = 0; n < neg_sel; ++n) { - sel_indices.insert(loss_idx[n].second); - } - - for (int m = 0; m < prior_num; ++m) { - if (match_indices(n, m) > -1) { - if (mining_type == MiningType::kHardExample && - sel_indices.find(m) == sel_indices.end()) { - match_indices_et(n, m) = -1; - } - } else { - if (sel_indices.find(m) != sel_indices.end()) { - neg_indices.push_back(m); - } - } - } - all_neg_indices.push_back(neg_indices); - all_neg_num += neg_indices.size(); - } - - framework::LoD out_neg_indics_lod; - out_neg_indics_lod.resize(1); - int 
neg_offset = 0; - auto neg_data = out_neg_indics->mutable_data( - framework::make_ddim({all_neg_num, 1}), ctx.GetPlace()); - out_neg_indics_lod[0].push_back(neg_offset); - for (auto neg_indices : all_neg_indices) { - for (auto neg_idx : neg_indices) { - neg_data[neg_offset++] = neg_idx; - } - out_neg_indics_lod[0].push_back(neg_offset); - } - out_neg_indics->set_lod(out_neg_indics_lod); - return; - } -}; -} // namespace operators - -} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py index e7dd04740a..c27573c3d6 100755 --- a/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py +++ b/python/paddle/v2/fluid/tests/test_mine_hard_examples_op.py @@ -1,16 +1,17 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest import numpy as np import sys @@ -24,8 +25,8 @@ class TestMineHardExamplesOp(OpTest): self.inputs = { 'ClsLoss': self.cls_loss, 'LocLoss': self.loc_loss, - 'MatchIndics': self.match_indices, - 'MatchDis': self.match_dis + 'MatchIndices': self.match_indices, + 'MatchDist': self.match_dis } self.attrs = { @@ -36,8 +37,8 @@ class TestMineHardExamplesOp(OpTest): } self.outputs = { - 'NegIndics': (self.neg_indices, self.neg_indices_lod), - 'UpdatedMatchIndics': self.updated_match_indices + 'NegIndices': (self.neg_indices, self.neg_indices_lod), + 'UpdatedMatchIndices': self.updated_match_indices } def test_check_output(self): From 4284b857cb61f9ad090044834f3c0f62c339c0b2 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 2 Feb 2018 15:45:13 +0800 Subject: [PATCH 3/4] update mine_hard_examples op --- paddle/operators/mine_hard_examples_op.cc | 52 ++++++++++++++--------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc index 603368f93c..2a3bd139ed 100644 --- a/paddle/operators/mine_hard_examples_op.cc +++ b/paddle/operators/mine_hard_examples_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -38,7 +38,7 @@ inline bool IsEligibleMining(const MiningType mining_type, const int match_idx, } } -MiningType GetMiningType(std::string str) { +inline MiningType GetMiningType(std::string str) { if (str == "max_negative") { return MiningType::kMaxNegative; } else if (str == "hard_example") { @@ -112,7 +112,7 @@ class MineHardExamplesKernel : public framework::OpKernel { neg_sel = std::min(sample_size, neg_sel); } - std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); + std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend); std::set sel_indices; std::vector neg_indices; std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, @@ -121,18 +121,27 @@ class MineHardExamplesKernel : public framework::OpKernel { return static_cast(l.second); }); - for (int m = 0; m < prior_num; ++m) { - if (match_indices(n, m) > -1) { - if (mining_type == MiningType::kHardExample && - sel_indices.find(m) == sel_indices.end()) { - match_indices_et(n, m) = -1; + if (mining_type == MiningType::kHardExample) { + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) > -1) { + if (sel_indices.find(m) == sel_indices.end()) { + match_indices_et(n, m) = -1; + } + } else { + if (sel_indices.find(m) != sel_indices.end()) { + neg_indices.push_back(m); + } } - } else { - if (sel_indices.find(m) != sel_indices.end()) { + } + } else { + for (int m = 0; m < prior_num; ++m) { + if (match_indices(n, m) == -1 && + sel_indices.find(m) != sel_indices.end()) { neg_indices.push_back(m); } } } + all_neg_indices.push_back(neg_indices); batch_starts.push_back(batch_starts.back() + neg_indices.size()); } @@ -253,7 +262,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { "[N, Np], N is the batch size and Np is the number of prior box."); AddInput("LocLoss", "(Tensor, optional, default Tensor), The localization loss " - "wit shape [N, Np], N is the batch size and Np is the number of " + "with shape [N, Np], N is the batch size and Np is the number of " "prior box.") .AsDispensable(); AddInput("MatchIndices", @@ -267,15 +276,15 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { "Np], N is the batch size and Np is the number of prior box."); AddAttr("neg_pos_ratio", "(float) The ratio of the negative box to the positive " - "box. Use only when mining_type is equal to max_negative.") + "box. Use only when mining_type is max_negative.") .SetDefault(1.0); AddAttr("neg_dist_threshold", - "(float) The negative box dis value threshold. " - "Use only when mining_type is equal to max_negative.") + "(float) The negative overlap upper bound for the unmatched " + "predictions. Use only when mining_type is max_negative.") .SetDefault(0.5); AddAttr("sample_size", "(float) The max sample size of negative box. Use only when " - "mining_type is equal to hard_example.") + "mining_type is hard_example.") .SetDefault(0); AddAttr("mining_type", "(float) The mining algorithm name, the value is " @@ -295,7 +304,7 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("UpdatedMatchIndices", "(Tensor) The output of updated MatchIndices, a tensor with " - "shape [N, Np]. Only update when mining_type is equal to " + "shape [N, Np]. Only update when mining_type is " "hard_example. 
The input MatchIndices elements will be update to " "-1 when it is not in the candidate high loss list of negative " "examples."); @@ -303,11 +312,12 @@ class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Mine hard examples Operator. This operator implements hard example mining to select a subset of negative box indices. -For each image, selects the box with highest losses. subject to the condition that the box cannot have -an Matcht > neg_dist_threshold when mining_type is equals max_negative. The selected number is -min(sample_size, max_negative_box_number) when mining_type is equals hard_example, -or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is -equals max_negative, where the max_negative_box_number is the count of MatchIndices elements with value -1. +For each image, selects the box with highest losses. subject to the condition that the +box cannot have an Matcht > neg_dist_threshold when mining_type is max_negative. +The selected number is min(sample_size, max_negative_box_number) when mining_type is +hard_example, or min(neg_pos_ratio * positive_box_number, max_negative_box_number) +when mining_type is max_negative, where the max_negative_box_number is the count of +MatchIndices elements with value -1. )DOC"); } }; From 8137dd9b5ed0cab202006e2b7d0ab6ff4bee34df Mon Sep 17 00:00:00 2001 From: wanghaox Date: Fri, 2 Feb 2018 16:53:33 +0800 Subject: [PATCH 4/4] update mine_hard_examples_op --- paddle/operators/mine_hard_examples_op.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/paddle/operators/mine_hard_examples_op.cc b/paddle/operators/mine_hard_examples_op.cc index 2a3bd139ed..051cc24706 100644 --- a/paddle/operators/mine_hard_examples_op.cc +++ b/paddle/operators/mine_hard_examples_op.cc @@ -117,7 +117,7 @@ class MineHardExamplesKernel : public framework::OpKernel { std::vector neg_indices; std::transform(loss_idx.begin(), loss_idx.begin() + neg_sel, std::inserter(sel_indices, sel_indices.begin()), - [](std::pair l) -> int { + [](std::pair& l) -> int { return static_cast(l.second); }); @@ -134,12 +134,8 @@ class MineHardExamplesKernel : public framework::OpKernel { } } } else { - for (int m = 0; m < prior_num; ++m) { - if (match_indices(n, m) == -1 && - sel_indices.find(m) != sel_indices.end()) { - neg_indices.push_back(m); - } - } + neg_indices.resize(sel_indices.size()); + std::copy(sel_indices.begin(), sel_indices.end(), neg_indices.begin()); } all_neg_indices.push_back(neg_indices);
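For readers who want to sanity-check the expected outputs in test_mine_hard_examples_op.py without building the C++ kernel, below is a minimal NumPy sketch of the mining logic described in the patches above. The helper name mine_hard_examples_ref and its return layout are illustrative only and are not part of the Paddle API; loc_loss is assumed to be supplied whenever mining_type is "hard_example", and tie-breaking on equal losses is left to Python's stable sort, which the fixtures above do not exercise.

import numpy as np

def mine_hard_examples_ref(cls_loss, match_indices, match_dist, loc_loss=None,
                           neg_pos_ratio=1.0, neg_dist_threshold=0.5,
                           sample_size=0, mining_type="max_negative"):
    """NumPy reference sketch of the mine_hard_examples kernel (not the Paddle API)."""
    batch_size, prior_num = cls_loss.shape
    updated = match_indices.copy()
    neg_indices, lod = [], [0]
    for n in range(batch_size):
        # Collect eligible priors for this sample and their mining loss.
        cand = []
        for m in range(prior_num):
            if mining_type == "max_negative":
                # Only unmatched priors whose overlap is below the threshold compete.
                if match_indices[n, m] == -1 and match_dist[n, m] < neg_dist_threshold:
                    cand.append((cls_loss[n, m], m))
            elif mining_type == "hard_example":
                # All priors compete, ranked by classification + localization loss.
                cand.append((cls_loss[n, m] + loc_loss[n, m], m))
        # Decide how many candidates to keep.
        if mining_type == "max_negative":
            num_pos = int((match_indices[n] != -1).sum())
            neg_sel = min(int(num_pos * neg_pos_ratio), len(cand))
        else:
            neg_sel = min(sample_size, len(cand))
        cand.sort(key=lambda t: t[0], reverse=True)
        selected = {m for _, m in cand[:neg_sel]}
        # Emit negative indices in ascending prior order; in hard_example mode,
        # positives that fall out of the top-loss list are reset to -1.
        for m in range(prior_num):
            if match_indices[n, m] == -1:
                if m in selected:
                    neg_indices.append([m])
            elif mining_type == "hard_example" and m not in selected:
                updated[n, m] = -1
        lod.append(len(neg_indices))
    return np.array(neg_indices, dtype="int32"), [lod], updated

# Cross-check against the max_negative fixture in the unit test above:
cls = np.array([[0.1, 0.1, 0.3], [0.3, 0.1, 0.1]], dtype="float32")
idx = np.array([[0, -1, -1], [-1, 0, -1]], dtype="int32")
dist = np.array([[0.2, 0.4, 0.8], [0.1, 0.9, 0.3]], dtype="float32")
neg, lod, upd = mine_hard_examples_ref(cls, idx, dist)
assert neg.tolist() == [[1], [0]] and lod == [[0, 1, 2]]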