add collect fpn proposals op,test=develop (#16074)
* add collect fpn proposals op,test=developrevert-17080-prepare_data
parent
60be66e2c0
commit
1c6d064627
@ -0,0 +1,108 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
|
||||
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
class CollectFpnProposalsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *context) const override {
|
||||
PADDLE_ENFORCE(context->HasInputs("MultiLevelRois"),
|
||||
"Inputs(MultiLevelRois) shouldn't be null");
|
||||
PADDLE_ENFORCE(context->HasInputs("MultiLevelScores"),
|
||||
"Inputs(MultiLevelScores) shouldn't be null");
|
||||
PADDLE_ENFORCE(context->HasOutput("FpnRois"),
|
||||
"Outputs(MultiFpnRois) of DistributeOp should not be null");
|
||||
auto roi_dims = context->GetInputsDim("MultiLevelRois");
|
||||
auto score_dims = context->GetInputsDim("MultiLevelScores");
|
||||
auto post_nms_topN = context->Attrs().Get<int>("post_nms_topN");
|
||||
std::vector<int64_t> out_dims;
|
||||
for (auto &roi_dim : roi_dims) {
|
||||
PADDLE_ENFORCE_EQ(roi_dim[1], 4,
|
||||
"Second dimension of Input(MultiLevelRois) must be 4");
|
||||
}
|
||||
for (auto &score_dim : score_dims) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
score_dim[1], 1,
|
||||
"Second dimension of Input(MultiLevelScores) must be 1");
|
||||
}
|
||||
context->SetOutputDim("FpnRois", {post_nms_topN, 4});
|
||||
if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
|
||||
// in Kernel.
|
||||
context->ShareLoD("MultiLevelRois", "FpnRois");
|
||||
}
|
||||
if (context->IsRuntime()) {
|
||||
std::vector<framework::InferShapeVarPtr> roi_inputs =
|
||||
context->GetInputVarPtrs("MultiLevelRois");
|
||||
std::vector<framework::InferShapeVarPtr> score_inputs =
|
||||
context->GetInputVarPtrs("MultiLevelScores");
|
||||
for (size_t i = 0; i < roi_inputs.size(); ++i) {
|
||||
framework::Variable *roi_var =
|
||||
boost::get<framework::Variable *>(roi_inputs[i]);
|
||||
framework::Variable *score_var =
|
||||
boost::get<framework::Variable *>(score_inputs[i]);
|
||||
auto &roi_lod = roi_var->Get<LoDTensor>().lod();
|
||||
auto &score_lod = score_var->Get<LoDTensor>().lod();
|
||||
PADDLE_ENFORCE_EQ(roi_lod, score_lod,
|
||||
"Inputs(MultiLevelRois) and Inputs(MultiLevelScores) "
|
||||
"should have same lod.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
auto data_type =
|
||||
framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
|
||||
return framework::OpKernelType(data_type, ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("MultiLevelRois",
|
||||
"(LoDTensor) Multiple roi LoDTensors from each level in shape "
|
||||
"(N, 4), N is the number of RoIs")
|
||||
.AsDuplicable();
|
||||
AddInput("MultiLevelScores",
|
||||
"(LoDTensor) Multiple score LoDTensors from each level in shape"
|
||||
" (N, 1), N is the number of RoIs.")
|
||||
.AsDuplicable();
|
||||
AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
|
||||
AddAttr<int>("post_nms_topN",
|
||||
"Select post_nms_topN RoIs from"
|
||||
" all images and all fpn layers");
|
||||
AddComment(R"DOC(
|
||||
This operator concats all proposals from different images
|
||||
and different FPN levels. Then sort all of those proposals
|
||||
by objectness confidence. Select the post_nms_topN RoIs in
|
||||
total. Finally, re-sort the RoIs in the order of batch index.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(collect_fpn_proposals, ops::CollectFpnProposalsOp,
|
||||
ops::CollectFpnProposalsOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
|
||||
ops::CollectFpnProposalsOpKernel<float>,
|
||||
ops::CollectFpnProposalsOpKernel<double>);
|
@ -0,0 +1,211 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <paddle/fluid/memory/allocation/allocator.h>
|
||||
#include "cub/cub.cuh"
|
||||
#include "paddle/fluid/framework/mixed_vector.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/memory/memcpy.h"
|
||||
#include "paddle/fluid/operators/detection/bbox_util.h"
|
||||
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
|
||||
#include "paddle/fluid/operators/gather.cu.h"
|
||||
#include "paddle/fluid/operators/math/concat_and_split.h"
|
||||
#include "paddle/fluid/operators/strided_memcpy.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
|
||||
static constexpr int kNumCUDAThreads = 64;
|
||||
static constexpr int kNumMaxinumNumBlocks = 4096;
|
||||
|
||||
const int kBBoxSize = 4;
|
||||
|
||||
static inline int NumBlocks(const int N) {
|
||||
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
|
||||
kNumMaxinumNumBlocks);
|
||||
}
|
||||
|
||||
static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids,
|
||||
int* length_lod) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (nthreads);
|
||||
i += blockDim.x * gridDim.x) {
|
||||
platform::CudaAtomicAdd(length_lod + batch_ids[i], 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois");
|
||||
const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores");
|
||||
auto fpn_rois = ctx.Output<LoDTensor>("FpnRois");
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
const int post_nms_topN = ctx.Attr<int>("post_nms_topN");
|
||||
|
||||
// concat inputs along axis = 0
|
||||
int roi_offset = 0;
|
||||
int score_offset = 0;
|
||||
int total_roi_num = 0;
|
||||
for (size_t i = 0; i < roi_ins.size(); ++i) {
|
||||
total_roi_num += roi_ins[i]->dims()[0];
|
||||
}
|
||||
|
||||
int real_post_num = min(post_nms_topN, total_roi_num);
|
||||
fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
|
||||
Tensor concat_rois;
|
||||
Tensor concat_scores;
|
||||
T* concat_rois_data = concat_rois.mutable_data<T>(
|
||||
{total_roi_num, kBBoxSize}, dev_ctx.GetPlace());
|
||||
T* concat_scores_data =
|
||||
concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace());
|
||||
Tensor roi_batch_id_list;
|
||||
roi_batch_id_list.Resize({total_roi_num});
|
||||
int* roi_batch_id_data =
|
||||
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
|
||||
int index = 0;
|
||||
int lod_size;
|
||||
auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());
|
||||
|
||||
for (size_t i = 0; i < roi_ins.size(); ++i) {
|
||||
auto roi_in = roi_ins[i];
|
||||
auto score_in = score_ins[i];
|
||||
auto roi_lod = roi_in->lod().back();
|
||||
lod_size = roi_lod.size() - 1;
|
||||
for (size_t n = 0; n < lod_size; ++n) {
|
||||
for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) {
|
||||
roi_batch_id_data[index++] = n;
|
||||
}
|
||||
}
|
||||
|
||||
memory::Copy(place, concat_rois_data + roi_offset, place,
|
||||
roi_in->data<T>(), roi_in->numel() * sizeof(T),
|
||||
dev_ctx.stream());
|
||||
memory::Copy(place, concat_scores_data + score_offset, place,
|
||||
score_in->data<T>(), score_in->numel() * sizeof(T),
|
||||
dev_ctx.stream());
|
||||
roi_offset += roi_in->numel();
|
||||
score_offset += score_in->numel();
|
||||
}
|
||||
|
||||
// copy batch id list to GPU
|
||||
Tensor roi_batch_id_list_gpu;
|
||||
framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(),
|
||||
&roi_batch_id_list_gpu);
|
||||
|
||||
Tensor index_in_t;
|
||||
int* idx_in =
|
||||
index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
|
||||
platform::ForRange<platform::CUDADeviceContext> for_range_total(
|
||||
dev_ctx, total_roi_num);
|
||||
for_range_total(RangeInitFunctor{0, 1, idx_in});
|
||||
|
||||
Tensor keys_out_t;
|
||||
T* keys_out =
|
||||
keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace());
|
||||
Tensor index_out_t;
|
||||
int* idx_out =
|
||||
index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
|
||||
|
||||
// Determine temporary device storage requirements
|
||||
size_t temp_storage_bytes = 0;
|
||||
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
||||
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
|
||||
idx_out, total_roi_num);
|
||||
// Allocate temporary storage
|
||||
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
|
||||
memory::Allocator::kScratchpad);
|
||||
|
||||
// Run sorting operation
|
||||
// sort score to get corresponding index
|
||||
cub::DeviceRadixSort::SortPairsDescending<T, int>(
|
||||
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
|
||||
keys_out, idx_in, idx_out, total_roi_num);
|
||||
index_out_t.Resize({real_post_num});
|
||||
Tensor sorted_rois;
|
||||
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
|
||||
Tensor sorted_batch_id;
|
||||
sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
|
||||
GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
|
||||
GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
|
||||
&sorted_batch_id);
|
||||
|
||||
Tensor batch_index_t;
|
||||
int* batch_idx_in =
|
||||
batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
|
||||
platform::ForRange<platform::CUDADeviceContext> for_range_post(
|
||||
dev_ctx, real_post_num);
|
||||
for_range_post(RangeInitFunctor{0, 1, batch_idx_in});
|
||||
|
||||
Tensor out_id_t;
|
||||
int* out_id_data =
|
||||
out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
|
||||
// Determine temporary device storage requirements
|
||||
temp_storage_bytes = 0;
|
||||
cub::DeviceRadixSort::SortPairs<int, int>(
|
||||
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
|
||||
batch_idx_in, index_out_t.data<int>(), real_post_num);
|
||||
// Allocate temporary storage
|
||||
d_temp_storage = memory::Alloc(place, temp_storage_bytes,
|
||||
memory::Allocator::kScratchpad);
|
||||
|
||||
// Run sorting operation
|
||||
// sort batch_id to get corresponding index
|
||||
cub::DeviceRadixSort::SortPairs<int, int>(
|
||||
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
|
||||
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
|
||||
|
||||
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
|
||||
|
||||
Tensor length_lod;
|
||||
int* length_lod_data =
|
||||
length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace());
|
||||
math::SetConstant<platform::CUDADeviceContext, int> set_zero;
|
||||
set_zero(dev_ctx, &length_lod, static_cast<int>(0));
|
||||
|
||||
int blocks = NumBlocks(real_post_num);
|
||||
int threads = kNumCUDAThreads;
|
||||
|
||||
// get length-based lod by batch ids
|
||||
GetLengthLoD<<<blocks, threads>>>(real_post_num, out_id_data,
|
||||
length_lod_data);
|
||||
std::vector<int> length_lod_cpu(lod_size);
|
||||
memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place,
|
||||
length_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
|
||||
dev_ctx.Wait();
|
||||
|
||||
std::vector<size_t> offset(1, 0);
|
||||
for (int i = 0; i < lod_size; ++i) {
|
||||
offset.emplace_back(offset.back() + length_lod_cpu[i]);
|
||||
}
|
||||
|
||||
framework::LoD lod;
|
||||
lod.emplace_back(offset);
|
||||
fpn_rois->set_lod(lod);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
collect_fpn_proposals,
|
||||
ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
|
||||
float>,
|
||||
ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
|
||||
double>);
|
@ -0,0 +1,149 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
#include "paddle/fluid/operators/gather.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
const int kBoxDim = 4;
|
||||
|
||||
template <typename T>
|
||||
struct ScoreWithID {
|
||||
T score;
|
||||
int batch_id;
|
||||
int index;
|
||||
int level;
|
||||
ScoreWithID() {
|
||||
batch_id = -1;
|
||||
index = -1;
|
||||
level = -1;
|
||||
}
|
||||
ScoreWithID(T score_, int batch_id_, int index_, int level_) {
|
||||
score = score_;
|
||||
batch_id = batch_id_;
|
||||
index = index_;
|
||||
level = level_;
|
||||
}
|
||||
};
|
||||
template <typename T>
|
||||
static inline bool CompareByScore(ScoreWithID<T> a, ScoreWithID<T> b) {
|
||||
return a.score >= b.score;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline bool CompareByBatchid(ScoreWithID<T> a, ScoreWithID<T> b) {
|
||||
return a.batch_id < b.batch_id;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto multi_layer_rois =
|
||||
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelRois");
|
||||
|
||||
auto multi_layer_scores =
|
||||
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
|
||||
|
||||
auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
|
||||
|
||||
int post_nms_topN = context.Attr<int>("post_nms_topN");
|
||||
|
||||
PADDLE_ENFORCE_GE(post_nms_topN, 0UL,
|
||||
"The parameter post_nms_topN must be a positive integer");
|
||||
|
||||
// assert that the length of Rois and scores are same
|
||||
PADDLE_ENFORCE(multi_layer_rois.size() == multi_layer_scores.size(),
|
||||
"DistributeFpnProposalsOp need 1 level of LoD");
|
||||
// Check if the lod information of two LoDTensor is same
|
||||
const int num_fpn_level = multi_layer_rois.size();
|
||||
std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
|
||||
for (int i = 0; i < num_fpn_level; ++i) {
|
||||
auto cur_rois_lod = multi_layer_rois[i]->lod().back();
|
||||
integral_of_all_rois[i + 1] =
|
||||
integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1];
|
||||
}
|
||||
|
||||
// concatenate all fpn rois scores into a list
|
||||
// create a vector to store all scores
|
||||
std::vector<ScoreWithID<T>> scores_of_all_rois(
|
||||
integral_of_all_rois[num_fpn_level], ScoreWithID<T>());
|
||||
for (int i = 0; i < num_fpn_level; ++i) {
|
||||
const T* cur_level_scores = multi_layer_scores[i]->data<T>();
|
||||
int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
|
||||
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
|
||||
int cur_batch_id = 0;
|
||||
for (int j = 0; j < cur_level_num; ++j) {
|
||||
if (j >= cur_scores_lod[cur_batch_id + 1]) {
|
||||
cur_batch_id++;
|
||||
}
|
||||
int cur_index = j + integral_of_all_rois[i];
|
||||
scores_of_all_rois[cur_index].score = cur_level_scores[j];
|
||||
scores_of_all_rois[cur_index].index = j;
|
||||
scores_of_all_rois[cur_index].level = i;
|
||||
scores_of_all_rois[cur_index].batch_id = cur_batch_id;
|
||||
}
|
||||
}
|
||||
// keep top post_nms_topN rois
|
||||
// sort the rois by the score
|
||||
if (post_nms_topN > integral_of_all_rois[num_fpn_level]) {
|
||||
post_nms_topN = integral_of_all_rois[num_fpn_level];
|
||||
}
|
||||
std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
|
||||
CompareByScore<T>);
|
||||
scores_of_all_rois.resize(post_nms_topN);
|
||||
// sort by batch id
|
||||
std::stable_sort(scores_of_all_rois.begin(), scores_of_all_rois.end(),
|
||||
CompareByBatchid<T>);
|
||||
// create a pointer array
|
||||
std::vector<const T*> multi_fpn_rois_data(num_fpn_level);
|
||||
for (int i = 0; i < num_fpn_level; ++i) {
|
||||
multi_fpn_rois_data[i] = multi_layer_rois[i]->data<T>();
|
||||
}
|
||||
// initialize the outputs
|
||||
fpn_rois->mutable_data<T>({post_nms_topN, kBoxDim}, context.GetPlace());
|
||||
T* fpn_rois_data = fpn_rois->data<T>();
|
||||
std::vector<size_t> lod0(1, 0);
|
||||
int cur_batch_id = 0;
|
||||
for (int i = 0; i < post_nms_topN; ++i) {
|
||||
int cur_fpn_level = scores_of_all_rois[i].level;
|
||||
int cur_level_index = scores_of_all_rois[i].index;
|
||||
memcpy(fpn_rois_data,
|
||||
multi_fpn_rois_data[cur_fpn_level] + cur_level_index * kBoxDim,
|
||||
kBoxDim * sizeof(T));
|
||||
fpn_rois_data += kBoxDim;
|
||||
if (scores_of_all_rois[i].batch_id != cur_batch_id) {
|
||||
cur_batch_id = scores_of_all_rois[i].batch_id;
|
||||
lod0.emplace_back(i);
|
||||
}
|
||||
}
|
||||
lod0.emplace_back(post_nms_topN);
|
||||
framework::LoD lod;
|
||||
lod.emplace_back(lod0);
|
||||
fpn_rois->set_lod(lod);
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import math
|
||||
import sys
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestCollectFPNProposalstOp(OpTest):
|
||||
def set_data(self):
|
||||
self.init_test_case()
|
||||
self.make_rois()
|
||||
self.scores_input = [('y%d' % i,
|
||||
(self.scores[i].reshape(-1, 1), self.rois_lod[i]))
|
||||
for i in range(self.num_level)]
|
||||
self.rois, self.lod = self.calc_rois_collect()
|
||||
inputs_x = [('x%d' % i, (self.roi_inputs[i][:, 1:], self.rois_lod[i]))
|
||||
for i in range(self.num_level)]
|
||||
self.inputs = {
|
||||
'MultiLevelRois': inputs_x,
|
||||
"MultiLevelScores": self.scores_input
|
||||
}
|
||||
self.attrs = {'post_nms_topN': self.post_nms_top_n, }
|
||||
self.outputs = {'FpnRois': (self.rois, [self.lod])}
|
||||
|
||||
def init_test_case(self):
|
||||
self.post_nms_top_n = 20
|
||||
self.images_shape = [100, 100]
|
||||
|
||||
def resort_roi_by_batch_id(self, rois):
|
||||
batch_id_list = rois[:, 0]
|
||||
batch_size = int(batch_id_list.max())
|
||||
sorted_rois = []
|
||||
new_lod = []
|
||||
for batch_id in range(batch_size + 1):
|
||||
sub_ind = np.where(batch_id_list == batch_id)[0]
|
||||
sub_rois = rois[sub_ind, 1:]
|
||||
sorted_rois.append(sub_rois)
|
||||
new_lod.append(len(sub_rois))
|
||||
new_rois = np.concatenate(sorted_rois)
|
||||
return new_rois, new_lod
|
||||
|
||||
def calc_rois_collect(self):
|
||||
roi_inputs = np.concatenate(self.roi_inputs)
|
||||
scores = np.concatenate(self.scores)
|
||||
inds = np.argsort(-scores)[:self.post_nms_top_n]
|
||||
rois = roi_inputs[inds, :]
|
||||
new_rois, new_lod = self.resort_roi_by_batch_id(rois)
|
||||
return new_rois, new_lod
|
||||
|
||||
def make_rois(self):
|
||||
self.num_level = 4
|
||||
self.roi_inputs = []
|
||||
self.scores = []
|
||||
self.rois_lod = [[[20, 10]], [[30, 20]], [[20, 30]], [[10, 10]]]
|
||||
for lvl in range(self.num_level):
|
||||
rois = []
|
||||
scores_pb = []
|
||||
lod = self.rois_lod[lvl][0]
|
||||
bno = 0
|
||||
for roi_num in lod:
|
||||
for i in range(roi_num):
|
||||
xywh = np.random.rand(4)
|
||||
xy1 = xywh[0:2] * 20
|
||||
wh = xywh[2:4] * (self.images_shape - xy1)
|
||||
xy2 = xy1 + wh
|
||||
roi = [bno, xy1[0], xy1[1], xy2[0], xy2[1]]
|
||||
rois.append(roi)
|
||||
bno += 1
|
||||
scores_pb.extend(list(np.random.uniform(0.0, 1.0, roi_num)))
|
||||
rois = np.array(rois).astype("float32")
|
||||
self.roi_inputs.append(rois)
|
||||
scores_pb = np.array(scores_pb).astype("float32")
|
||||
self.scores.append(scores_pb)
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "collect_fpn_proposals"
|
||||
self.set_data()
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue