parent 38c61053ff
commit c5a14ed4cd
@@ -0,0 +1,184 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/mine_hard_examples_op.h"

namespace paddle {
namespace operators {

class MineHardExamplesOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("ClsLoss"),
                   "Input(ClsLoss) of MineHardExamplesOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasInput("MatchIndics"),
        "Input(MatchIndics) of MineHardExamplesOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("MatchDis"),
                   "Input(MatchDis) of MineHardExamplesOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("NegIndics"),
        "Output(NegIndics) of MineHardExamplesOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("UpdatedMatchIndics"),
        "Output(UpdatedMatchIndics) of MineHardExamplesOp should not be null.");

    auto cls_loss_dims = ctx->GetInputDim("ClsLoss");
    auto idx_dims = ctx->GetInputDim("MatchIndics");
    auto dis_dims = ctx->GetInputDim("MatchDis");

    PADDLE_ENFORCE_EQ(cls_loss_dims.size(), 2UL,
                      "The shape of ClsLoss is [N, Np].");
    PADDLE_ENFORCE_EQ(idx_dims.size(), 2UL,
                      "The shape of MatchIndics is [N, Np].");
    PADDLE_ENFORCE_EQ(dis_dims.size(), 2UL,
                      "The shape of MatchDis is [N, Np].");

    if (ctx->HasInput("LocLoss")) {
      auto loc_loss_dims = ctx->GetInputDim("LocLoss");
      PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL,
                        "The shape of LocLoss is [N, Np].");
      PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0],
                        "Batch size of ClsLoss and LocLoss must be the same.");
      PADDLE_ENFORCE_EQ(
          cls_loss_dims[1], loc_loss_dims[1],
          "Prior box number of ClsLoss and LocLoss must be the same.");
    }

    PADDLE_ENFORCE_EQ(
        cls_loss_dims[0], idx_dims[0],
        "Batch size of ClsLoss and MatchIndics must be the same.");
    PADDLE_ENFORCE_EQ(
        cls_loss_dims[1], idx_dims[1],
        "Prior box number of ClsLoss and MatchIndics must be the same.");

    PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0],
                      "Batch size of ClsLoss and MatchDis must be the same.");
    PADDLE_ENFORCE_EQ(
        cls_loss_dims[1], dis_dims[1],
        "Prior box number of ClsLoss and MatchDis must be the same.");

    auto mining_type =
        GetMiningType(ctx->Attrs().Get<std::string>("mining_type"));

    PADDLE_ENFORCE_NE(mining_type, MiningType::kNone,
                      "mining_type must be hard_example or max_negative");

    if (mining_type == MiningType::kMaxNegative) {
      auto neg_pos_ratio = ctx->Attrs().Get<float>("neg_pos_ratio");
      auto neg_dis_threshold = ctx->Attrs().Get<float>("neg_dis_threshold");
      PADDLE_ENFORCE_GT(
          neg_pos_ratio, 0.0f,
          "neg_pos_ratio must be greater than zero in max_negative mode");
      PADDLE_ENFORCE_GT(
          neg_dis_threshold, 0.0f,
          "neg_dis_threshold must be greater than zero in max_negative mode");
    } else if (mining_type == MiningType::kHardExample) {
      auto sample_size = ctx->Attrs().Get<int>("sample_size");
      PADDLE_ENFORCE_GT(
          sample_size, 0,
          "sample_size must be greater than zero in hard_example mode");
    }

    ctx->SetOutputDim("UpdatedMatchIndics", idx_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("ClsLoss")->type()),
        ctx.device_context());
  }
};

class MineHardExamplesOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MineHardExamplesOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "ClsLoss",
        "(Tensor, default Tensor<float>), The classification loss with shape "
        "[N, Np], N is the batch size and Np is the number of prior boxes.");
    AddInput("LocLoss",
             "(Tensor, optional, default Tensor<float>), The localization "
             "loss with shape [N, Np], N is the batch size and Np is the "
             "number of prior boxes.")
        .AsDispensable();
    AddInput("MatchIndics",
             "(Tensor, Tensor<int>), Matched indices with shape [N, Np], N is "
             "the batch size and Np is the number of prior boxes. "
             "MatchIndics[i][j] == -1 means box[j] does not match any "
             "entity, otherwise box[j] is matched to the entity at row "
             "MatchIndics[i][j].");
    AddInput("MatchDis",
             "(Tensor, default Tensor<float>) Matched distances with shape "
             "[N, Np], N is the batch size and Np is the number of prior "
             "boxes.");
    AddAttr<float>("neg_pos_ratio",
                   "(float) The ratio of negative boxes to positive "
                   "boxes. Used only when mining_type is max_negative.")
        .SetDefault(1.0);
    AddAttr<float>("neg_dis_threshold",
                   "(float) The negative box distance threshold. "
                   "Used only when mining_type is max_negative.")
        .SetDefault(0.5);
    AddAttr<int>("sample_size",
                 "(int) The maximum number of sampled negative boxes. Used "
                 "only when mining_type is hard_example.")
        .SetDefault(0);
    AddAttr<std::string>("mining_type",
                         "(string) The mining algorithm name, the value is "
                         "hard_example or max_negative.")
        .SetDefault("max_negative")
        .InEnum({"hard_example", "max_negative"});

AddOutput("NegIndics",
|
||||
"(LoDTensor) The output of negative example indics.a lod tensor "
|
||||
"with shape [Neg, 1]. The size of lod[0] is batch size, "
|
||||
"and each element is the box index. "
|
||||
"For example, the batch size is 2, the lod is [[0, 1, 2]], "
|
||||
"the sample 0's box 1(MatchIndics[0][1]) is selected, "
|
||||
"and sample 1's box 0 is selected. The output NegIndics is "
|
||||
"[[1], [0]].");
|
||||
|
||||
AddOutput("UpdatedMatchIndics",
|
||||
"(Tensor) The output of updated MatchIndics, a tensor with "
|
||||
"shape [N, M]. Only update when mining_type is equal to "
|
||||
"hard_example. The input MatchIndics elements will be update to "
|
||||
"-1 when it not in the highest loss list");
|
||||
|
||||
AddComment(R"DOC(
|
||||
Mine hard examples Operator.
|
||||
This operator implements hard example mining to select a subset of negative box indics.
|
||||
For each image, selects the box with highest losses. subject to the condition that the box cannot have
|
||||
an MatchDis > neg_dis_threshold when mining_type is equals max_negative. The selected number is
|
||||
min(sample_size, max_negative_box_number) when mining_type is equals hard_example,
|
||||
or min(neg_pos_ratio * positive_box_number, max_negative_box_number) when mining_type is
|
||||
equals max_negative, where the max_negative_box_number is the count of MatchIndics elements with value -1.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(mine_hard_examples, ops::MineHardExamplesOp,
                             ops::MineHardExamplesOpMaker);

REGISTER_OP_CPU_KERNEL(
    mine_hard_examples,
    ops::MineHardExamplesKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MineHardExamplesKernel<paddle::platform::CPUDeviceContext, double>);
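The DOC comment above describes the selection rule in prose. The snippet below is a minimal NumPy sketch of the max_negative branch only, added for illustration; it is not part of the patch, and the function name mine_max_negative is illustrative. Per image, unmatched boxes whose MatchDis is below neg_dis_threshold are candidates, at most neg_pos_ratio * positive_box_number of them are kept by decreasing ClsLoss, and the kept indices are emitted in box order together with LoD offsets.

import numpy as np

def mine_max_negative(cls_loss, match_indices, match_dis,
                      neg_pos_ratio=1.0, neg_dis_threshold=0.5):
    # Reference sketch of the max_negative rule described in the DOC above.
    neg_indices, lod = [], [0]
    for n in range(cls_loss.shape[0]):
        # Candidate negatives: unmatched boxes with a small enough match distance.
        eligible = np.where((match_indices[n] == -1) &
                            (match_dis[n] < neg_dis_threshold))[0]
        num_pos = int((match_indices[n] != -1).sum())
        neg_sel = min(int(num_pos * neg_pos_ratio), len(eligible))
        # Keep the neg_sel candidates with the highest classification loss,
        # then emit them in ascending box-index order, as the kernel does.
        top = eligible[np.argsort(-cls_loss[n, eligible])][:neg_sel]
        neg_indices.extend(sorted(top.tolist()))
        lod.append(len(neg_indices))
    return np.array(neg_indices, dtype=np.int32).reshape(-1, 1), [lod]

On the test data further below, this sketch reproduces NegIndics [[1], [0]] with lod [[0, 1, 2]].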
@@ -0,0 +1,148 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

enum MiningType { kNone = 0, kMaxNegative, kHardExample };

template <typename T>
bool SortScoreDescend(const std::pair<float, T>& pair1,
                      const std::pair<float, T>& pair2) {
  return pair1.first > pair2.first;
}

inline bool IsEligibleMining(const MiningType mining_type, const int match_idx,
                             const float match_dis,
                             const float neg_dis_threshold) {
  if (mining_type == MiningType::kMaxNegative) {
    return match_idx == -1 && match_dis < neg_dis_threshold;
  } else if (mining_type == MiningType::kHardExample) {
    return true;
  } else {
    return false;
  }
}

inline MiningType GetMiningType(std::string str) {
  if (str == "max_negative") {
    return MiningType::kMaxNegative;
  } else if (str == "hard_example") {
    return MiningType::kHardExample;
  } else {
    return MiningType::kNone;
  }
}

template <typename DeviceContext, typename T>
class MineHardExamplesKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_cls_loss = ctx.Input<framework::Tensor>("ClsLoss");
    auto* in_loc_loss = ctx.Input<framework::Tensor>("LocLoss");
    auto* in_matched_indics = ctx.Input<framework::Tensor>("MatchIndics");
    auto* in_match_dis = ctx.Input<framework::Tensor>("MatchDis");
    float neg_pos_ratio = ctx.Attr<float>("neg_pos_ratio");
    T neg_dis_threshold = static_cast<T>(ctx.Attr<float>("neg_dis_threshold"));
    int sample_size = ctx.Attr<int>("sample_size");
    MiningType mining_type =
        GetMiningType(ctx.Attr<std::string>("mining_type"));

    auto out_neg_indics = ctx.Output<framework::LoDTensor>("NegIndics");
    auto out_match_indics = ctx.Output<framework::Tensor>("UpdatedMatchIndics");

    framework::Copy(*in_matched_indics, ctx.GetPlace(), out_match_indics);

    int batch_size = in_matched_indics->dims()[0];
    int prior_num = in_matched_indics->dims()[1];

    auto match_indices = framework::EigenMatrix<int>::From(*in_matched_indics);

    auto match_indices_et =
        framework::EigenMatrix<int>::From(*out_match_indics);

    auto match_dis = framework::EigenMatrix<float>::From(*in_match_dis);
    auto cls_loss = framework::EigenMatrix<T>::From(*in_cls_loss);
    auto loc_loss = framework::EigenMatrix<T>::From(*in_loc_loss);

    std::vector<std::vector<int>> all_neg_indices;
    int all_neg_num = 0;
    for (int n = 0; n < batch_size; ++n) {
      std::vector<std::pair<float, size_t>> loss_idx;
      int neg_sel = 0;
      // Collect all eligible boxes together with their mining loss.
      for (int m = 0; m < prior_num; ++m) {
        if (IsEligibleMining(mining_type, match_indices(n, m), match_dis(n, m),
                             neg_dis_threshold)) {
          T loss = cls_loss(n, m);
          if (mining_type == MiningType::kHardExample) {
            loss = cls_loss(n, m) + loc_loss(n, m);
          }
          loss_idx.push_back(std::make_pair(loss, m));
          ++neg_sel;
        }
      }
      // Cap the number of selected boxes according to the mining type.
      if (mining_type == MiningType::kMaxNegative) {
        int num_pos = 0;
        for (int m = 0; m < prior_num; ++m) {
          if (match_indices(n, m) != -1) ++num_pos;
        }
        neg_sel = std::min(static_cast<int>(num_pos * neg_pos_ratio), neg_sel);
      } else if (mining_type == MiningType::kHardExample) {
        neg_sel = std::min(sample_size, neg_sel);
      }
      // Keep the neg_sel boxes with the highest loss.
      std::sort(loss_idx.begin(), loss_idx.end(), SortScoreDescend<size_t>);
      std::set<int> sel_indices;
      std::vector<int> neg_indices;
      for (int i = 0; i < neg_sel; ++i) {
        sel_indices.insert(loss_idx[i].second);
      }

      for (int m = 0; m < prior_num; ++m) {
        if (match_indices(n, m) > -1) {
          if (mining_type == MiningType::kHardExample &&
              sel_indices.find(m) == sel_indices.end()) {
            match_indices_et(n, m) = -1;
          }
        } else {
          if (sel_indices.find(m) != sel_indices.end()) {
            neg_indices.push_back(m);
          }
        }
      }
      all_neg_indices.push_back(neg_indices);
      all_neg_num += neg_indices.size();
    }

    // Pack the per-sample negative indices into a LoD tensor of shape [Neg, 1].
    framework::LoD out_neg_indics_lod;
    out_neg_indics_lod.resize(1);
    int neg_offset = 0;
    auto neg_data = out_neg_indics->mutable_data<int>(
        framework::make_ddim({all_neg_num, 1}), ctx.GetPlace());
    out_neg_indics_lod[0].push_back(neg_offset);
    for (const auto& neg_indices : all_neg_indices) {
      for (auto neg_idx : neg_indices) {
        neg_data[neg_offset++] = neg_idx;
      }
      out_neg_indics_lod[0].push_back(neg_offset);
    }
    out_neg_indics->set_lod(out_neg_indics_lod);
  }
};
}  // namespace operators

}  // namespace paddle
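For the kHardExample branch of the kernel above, here is a hedged NumPy sketch added for illustration (the function name mine_hard_example is illustrative, not part of the patch): every box is a candidate, the top sample_size boxes by ClsLoss + LocLoss are kept, positives that fall outside that set are reset to -1 in UpdatedMatchIndics, and unmatched boxes inside it become the negatives.

import numpy as np

def mine_hard_example(cls_loss, loc_loss, match_indices, sample_size):
    # Sketch of the kHardExample branch of MineHardExamplesKernel::Compute.
    updated = match_indices.copy()
    neg_indices, lod = [], [0]
    for n in range(cls_loss.shape[0]):
        total = cls_loss[n] + loc_loss[n]           # every box is eligible
        neg_sel = min(sample_size, total.shape[0])
        sel = set(np.argsort(-total)[:neg_sel].tolist())
        for m in range(total.shape[0]):
            if match_indices[n, m] > -1:
                if m not in sel:                    # positive dropped from the top-loss set
                    updated[n, m] = -1
            elif m in sel:                          # unmatched box kept as a negative
                neg_indices.append(m)
        lod.append(len(neg_indices))
    return updated, np.array(neg_indices, dtype=np.int32).reshape(-1, 1), [lod]

On the hard_example test case below, this sketch reproduces UpdatedMatchIndics [[0, -1, -1], [-1, -1, -1]], NegIndics [[2], [0], [2]], and lod [[0, 1, 3]].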
@@ -0,0 +1,99 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import sys
import math
from op_test import OpTest


class TestMineHardExamplesOp(OpTest):
    def set_data(self):
        self.init_test_data()
        self.inputs = {
            'ClsLoss': self.cls_loss,
            'LocLoss': self.loc_loss,
            'MatchIndics': self.match_indices,
            'MatchDis': self.match_dis
        }

        self.attrs = {
            'neg_pos_ratio': self.neg_pos_ratio,
            'neg_dis_threshold': self.neg_dis_threshold,
            'sample_size': self.sample_size,
            'mining_type': self.mining_type
        }

        self.outputs = {
            'NegIndics': (self.neg_indices, self.neg_indices_lod),
            'UpdatedMatchIndics': self.updated_match_indices
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # mine_hard_examples is registered without a gradient op.
        return

    def setUp(self):
        self.op_type = "mine_hard_examples"
        self.set_data()

    def init_test_data(self):
        self.neg_pos_ratio = 1.0
        self.neg_dis_threshold = 0.5
        self.sample_size = 0
        self.mining_type = "max_negative"
        self.cls_loss = np.array([[0.1, 0.1, 0.3],
                                  [0.3, 0.1, 0.1]]).astype('float32')

        self.loc_loss = np.array([[0.1, 0.2, 0.3],
                                  [0.3, 0.4, 0.1]]).astype('float32')

        self.match_dis = np.array([[0.2, 0.4, 0.8],
                                   [0.1, 0.9, 0.3]]).astype('float32')

        self.match_indices = np.array([[0, -1, -1],
                                       [-1, 0, -1]]).astype('int32')

        self.updated_match_indices = self.match_indices

        self.neg_indices_lod = [[0, 1, 2]]
        self.neg_indices = np.array([[1], [0]]).astype('int32')
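        # Worked check of the expected outputs above (max_negative, ratio 1.0,
        # threshold 0.5): in sample 0 the unmatched boxes are 1 (dis 0.4, eligible)
        # and 2 (dis 0.8, rejected); one positive box allows one negative -> box 1.
        # In sample 1 boxes 0 (dis 0.1) and 2 (dis 0.3) are both eligible; the
        # single highest ClsLoss wins -> box 0 (0.3 > 0.1). Hence NegIndics is
        # [[1], [0]] with lod [[0, 1, 2]], and MatchIndics is unchanged.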


class TestMineHardExamplesOpHardExample(TestMineHardExamplesOp):
    def init_test_data(self):
        super(TestMineHardExamplesOpHardExample, self).init_test_data()
        self.mining_type = "hard_example"
        self.sample_size = 2

        self.cls_loss = np.array([[0.5, 0.1, 0.3],
                                  [0.3, 0.1, 0.1]]).astype('float32')

        self.loc_loss = np.array([[0.2, 0.2, 0.3],
                                  [0.3, 0.1, 0.2]]).astype('float32')

        self.match_indices = np.array([[0, -1, -1],
                                       [-1, 0, -1]]).astype('int32')

        self.updated_match_indices = np.array([[0, -1, -1],
                                               [-1, -1, -1]]).astype('int32')

        self.neg_indices_lod = [[0, 1, 3]]
        self.neg_indices = np.array([[2], [0], [2]]).astype('int32')
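        # Worked check (hard_example, sample_size 2, total loss = ClsLoss + LocLoss):
        # sample 0 totals are [0.7, 0.3, 0.6] -> top-2 set {0, 2}; box 0 stays a
        # positive and box 2 becomes the negative -> [2]. Sample 1 totals are
        # [0.6, 0.2, 0.3] -> top-2 set {0, 2}; the positive box 1 is not in the
        # set, so it is reset to -1, and boxes 0 and 2 become negatives -> [0, 2].
        # Hence lod [[0, 1, 3]] and NegIndics [[2], [0], [2]].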


if __name__ == '__main__':
    unittest.main()