add generate_proposals_v2 op (#28214)

* add generate_proposals_v2 op
revert-28284-dev/pybind_version
wangguanzhong 5 years ago committed by GitHub
parent b96869bc31
commit 5262b02585
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -46,10 +46,12 @@ if(WITH_GPU)
set(TMPDEPS memory cub) set(TMPDEPS memory cub)
endif() endif()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS ${TMPDEPS}) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS ${TMPDEPS})
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op.cu DEPS ${TMPDEPS})
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS ${TMPDEPS}) detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS ${TMPDEPS})
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS}) detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS ${TMPDEPS})
else() else()
detection_library(generate_proposals_op SRCS generate_proposals_op.cc) detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc)
detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc) detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc) detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
endif() endif()

File diff suppressed because it is too large Load Diff

@ -21,6 +21,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
struct RangeInitFunctor { struct RangeInitFunctor {
int start; int start;
int delta; int delta;
@ -125,17 +127,45 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
} }
} }
// Calculate max IoU between each box and ground-truth and
// each row represents one box
template <typename T>
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
const T* iou_data = iou.data<T>();
int row = iou.dims()[0];
int col = iou.dims()[1];
T* max_iou_data = max_iou->data<T>();
for (int i = 0; i < row; ++i) {
const T* v = iou_data + i * col;
T max_v = *std::max_element(v, v + col);
max_iou_data[i] = max_v;
}
}
static void AppendProposals(framework::Tensor* dst, int64_t offset,
const framework::Tensor& src) {
auto* out_data = dst->data<void>();
auto* to_add_data = src.data<void>();
size_t size_of_t = framework::SizeOfType(src.type());
offset *= size_of_t;
std::memcpy(
reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(out_data) + offset),
to_add_data, src.numel() * size_of_t);
}
template <class T> template <class T>
void ClipTiledBoxes(const platform::DeviceContext& ctx, void ClipTiledBoxes(const platform::DeviceContext& ctx,
const framework::Tensor& im_info, const framework::Tensor& im_info,
const framework::Tensor& input_boxes, const framework::Tensor& input_boxes,
framework::Tensor* out) { framework::Tensor* out, bool is_scale = true) {
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T* out_data = out->mutable_data<T>(ctx.GetPlace());
const T* im_info_data = im_info.data<T>(); const T* im_info_data = im_info.data<T>();
const T* input_boxes_data = input_boxes.data<T>(); const T* input_boxes_data = input_boxes.data<T>();
T zero(0); T zero(0);
T im_w = round(im_info_data[1] / im_info_data[2]); T im_w =
T im_h = round(im_info_data[0] / im_info_data[2]); is_scale ? round(im_info_data[1] / im_info_data[2]) : im_info_data[1];
T im_h =
is_scale ? round(im_info_data[0] / im_info_data[2]) : im_info_data[0];
for (int64_t i = 0; i < input_boxes.numel(); ++i) { for (int64_t i = 0; i < input_boxes.numel(); ++i) {
if (i % 4 == 0) { if (i % 4 == 0) {
out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero); out_data[i] = std::max(std::min(input_boxes_data[i], im_w - 1), zero);
@ -149,19 +179,101 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
} }
} }
// Calculate max IoU between each box and ground-truth and // Filter the box with small area
// each row represents one box template <class T>
template <typename T> void FilterBoxes(const platform::DeviceContext& ctx,
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) { const framework::Tensor* boxes, float min_size,
const T* iou_data = iou.data<T>(); const framework::Tensor& im_info, bool is_scale,
int row = iou.dims()[0]; framework::Tensor* keep) {
int col = iou.dims()[1]; const T* im_info_data = im_info.data<T>();
T* max_iou_data = max_iou->data<T>(); const T* boxes_data = boxes->data<T>();
for (int i = 0; i < row; ++i) { keep->Resize({boxes->dims()[0]});
const T* v = iou_data + i * col; min_size = std::max(min_size, 1.0f);
T max_v = *std::max_element(v, v + col); int* keep_data = keep->mutable_data<int>(ctx.GetPlace());
max_iou_data[i] = max_v;
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (is_scale) {
ws = (boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_info_data[2] + 1;
hs =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_info_data[2] + 1;
}
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
keep->Resize({keep_len});
}
template <class T>
static void BoxCoder(const platform::DeviceContext& ctx,
framework::Tensor* all_anchors,
framework::Tensor* bbox_deltas,
framework::Tensor* variances,
framework::Tensor* proposals) {
T* proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
int64_t row = all_anchors->dims()[0];
int64_t len = all_anchors->dims()[1];
auto* bbox_deltas_data = bbox_deltas->data<T>();
auto* anchor_data = all_anchors->data<T>();
const T* variances_data = nullptr;
if (variances) {
variances_data = variances->data<T>();
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
if (variances) {
bbox_center_x =
variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
anchor_center_x;
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
kBBoxClipDefault)) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
kBBoxClipDefault)) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
} }
// return proposals;
} }
} // namespace operators } // namespace operators

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,229 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
namespace {
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
const platform::CUDADeviceContext &ctx, const Tensor &im_shape,
const Tensor &anchors, const Tensor &variances,
const Tensor &bbox_deltas, // [M, 4]
const Tensor &scores, // [N, 1]
int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
float eta) {
// 1. pre nms
Tensor scores_sort, index_sort;
SortDescending<T>(ctx, scores, &scores_sort, &index_sort);
int num = scores.numel();
int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel()
: pre_nms_top_n;
scores_sort.Resize({pre_nms_num, 1});
index_sort.Resize({pre_nms_num, 1});
// 2. box decode and clipping
Tensor proposals;
proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
{
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
for_range(BoxDecodeAndClipFunctor<T>{
anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
index_sort.data<int>(), im_shape.data<T>(), proposals.data<T>()});
}
// 3. filter
Tensor keep_index, keep_num_t;
keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
min_size = std::max(min_size, 1.0f);
auto stream = ctx.stream();
FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
proposals.data<T>(), im_shape.data<T>(), min_size, pre_nms_num,
keep_num_t.data<int>(), keep_index.data<int>(), false);
int keep_num;
const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(platform::CPUPlace(), &keep_num, gpu_place,
keep_num_t.data<int>(), sizeof(int), ctx.stream());
ctx.Wait();
keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &proposals_filter, static_cast<T>(0));
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);
if (nms_thresh <= 0) {
return std::make_pair(proposals_filter, scores_filter);
}
// 4. nms
Tensor keep_nms;
NMS<T>(ctx, proposals_filter, keep_index, nms_thresh, &keep_nms);
if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
keep_nms.Resize({post_nms_top_n});
}
Tensor scores_nms, proposals_nms;
proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);
return std::make_pair(proposals_nms, scores_nms);
}
} // namespace
template <typename DeviceContext, typename T>
class CUDAGenerateProposalsV2Kernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_shape = context.Input<Tensor>("ImShape");
auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Anchors", "GenerateProposals");
auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
"Input", "Variances", "GenerateProposals");
auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
int post_nms_top_n = context.Attr<int>("post_nms_topN");
float nms_thresh = context.Attr<float>("nms_thresh");
float min_size = context.Attr<float>("min_size");
float eta = context.Attr<float>("eta");
PADDLE_ENFORCE_GE(eta, 1.,
platform::errors::InvalidArgument(
"Not support adaptive NMS. The attribute 'eta' "
"should not less than 1. But received eta=[%d]",
eta));
auto &dev_ctx = context.template device_context<DeviceContext>();
auto scores_dim = scores->dims();
int64_t num = scores_dim[0];
int64_t c_score = scores_dim[1];
int64_t h_score = scores_dim[2];
int64_t w_score = scores_dim[3];
auto bbox_dim = bbox_deltas->dims();
int64_t c_bbox = bbox_dim[1];
int64_t h_bbox = bbox_dim[2];
int64_t w_bbox = bbox_dim[3];
Tensor bbox_deltas_swap, scores_swap;
bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
dev_ctx.GetPlace());
scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
dev_ctx.GetPlace());
math::Transpose<DeviceContext, T, 4> trans;
std::vector<int> axis = {0, 2, 3, 1};
trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
trans(dev_ctx, *scores, &scores_swap, axis);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
context.GetPlace());
rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
T *rpn_rois_data = rpn_rois->data<T>();
T *rpn_roi_probs_data = rpn_roi_probs->data<T>();
auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
auto cpu_place = platform::CPUPlace();
int64_t num_proposals = 0;
std::vector<size_t> offset(1, 0);
std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) {
Tensor im_shape_slice = im_shape->Slice(i, i + 1);
Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
Tensor scores_slice = scores_swap.Slice(i, i + 1);
bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
scores_slice.Resize({h_score * w_score * c_score, 1});
std::pair<Tensor, Tensor> box_score_pair =
ProposalForOneImage<T>(dev_ctx, im_shape_slice, anchors, variances,
bbox_deltas_slice, scores_slice, pre_nms_top_n,
post_nms_top_n, nms_thresh, min_size, eta);
Tensor &proposals = box_score_pair.first;
Tensor &scores = box_score_pair.second;
memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
proposals.data<T>(), sizeof(T) * proposals.numel(),
dev_ctx.stream());
memory::Copy(place, rpn_roi_probs_data + num_proposals, place,
scores.data<T>(), sizeof(T) * scores.numel(),
dev_ctx.stream());
dev_ctx.Wait();
num_proposals += proposals.dims()[0];
offset.emplace_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num,
dev_ctx.stream());
rpn_rois_num->Resize({num});
}
framework::LoD lod;
lod.emplace_back(offset);
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
rpn_rois->Resize({num_proposals, 4});
rpn_roi_probs->Resize({num_proposals, 1});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(generate_proposals_v2,
ops::CUDAGenerateProposalsV2Kernel<
paddle::platform::CUDADeviceContext, float>);

@ -99,5 +99,74 @@ T PolyIoU(const T* box1, const T* box2, const size_t box_size,
} }
} }
template <class T>
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
const std::vector<T>& scores) {
std::vector<std::pair<T, int>> sorted_indices;
sorted_indices.reserve(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sorted_indices.emplace_back(scores[i], i);
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
[](const std::pair<T, int>& a, const std::pair<T, int>& b) {
return a.first < b.first;
});
return sorted_indices;
}
template <typename T>
static inline framework::Tensor VectorToTensor(
const std::vector<T>& selected_indices, int selected_num) {
framework::Tensor keep_nms;
keep_nms.Resize({selected_num});
auto* keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
for (int i = 0; i < selected_num; ++i) {
keep_data[i] = selected_indices[i];
}
return keep_nms;
}
template <class T>
framework::Tensor NMS(const platform::DeviceContext& ctx,
framework::Tensor* bbox, framework::Tensor* scores,
T nms_threshold, float eta) {
int64_t num_boxes = bbox->dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox->dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices =
GetSortedScoreIndex<T>(scores_data);
std::vector<int> selected_indices;
int selected_num = 0;
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox->data<T>();
while (sorted_indices.size() != 0) {
int idx = sorted_indices.back().second;
bool flag = true;
for (int kept_idx : selected_indices) {
if (flag) {
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, false);
flag = (overlap <= adaptive_threshold);
} else {
break;
}
}
if (flag) {
selected_indices.push_back(idx);
++selected_num;
}
sorted_indices.erase(sorted_indices.end() - 1);
if (flag && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
return VectorToTensor(selected_indices, selected_num);
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -81,6 +81,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}}, {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}}, {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"multiclass_nms3", {"Out", "NmsRoisNum"}}, {"multiclass_nms3", {"Out", "NmsRoisNum"}},
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
}; };
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are

@ -0,0 +1,238 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from test_multiclass_nms_op import nms
from test_anchor_generator_op import anchor_generator_in_python
import copy
from test_generate_proposals_op import clip_tiled_boxes, box_coder, nms
def generate_proposals_v2_in_python(scores, bbox_deltas, im_shape, anchors,
variances, pre_nms_topN, post_nms_topN,
nms_thresh, min_size, eta):
all_anchors = anchors.reshape(-1, 4)
rois = np.empty((0, 5), dtype=np.float32)
roi_probs = np.empty((0, 1), dtype=np.float32)
rpn_rois = []
rpn_roi_probs = []
rois_num = []
num_images = scores.shape[0]
for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image(
im_shape[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
rois_num.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, rois_num
def proposal_for_one_image(im_shape, all_anchors, variances, bbox_deltas,
scores, pre_nms_topN, post_nms_topN, nms_thresh,
min_size, eta):
# Transpose and reshape predicted bbox transformations to get them
# into the same order as the anchors:
# - bbox deltas will be (4 * A, H, W) format from conv output
# - transpose to (H, W, 4 * A)
# - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
# in slowest to fastest order to match the enumerated anchors
bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
all_anchors = all_anchors.reshape(-1, 4)
variances = variances.reshape(-1, 4)
# Same story for the scores:
# - scores are (A, H, W) format from conv output
# - transpose to (H, W, A)
# - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
# to match the order of anchors and bbox_deltas
scores = scores.transpose((1, 2, 0)).reshape(-1, 1)
# sort all (proposal, score) pairs by score from highest to lowest
# take top pre_nms_topN (e.g. 6000)
if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
order = np.argsort(-scores.squeeze())
else:
# Avoid sorting possibly large arrays;
# First partition to get top K unsorted
# and then sort just those
inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
order = np.argsort(-scores[inds].squeeze())
order = inds[order]
scores = scores[order, :]
bbox_deltas = bbox_deltas[order, :]
all_anchors = all_anchors[order, :]
proposals = box_coder(all_anchors, bbox_deltas, variances)
# clip proposals to image (may result in proposals with zero area
# that will be removed in the next step)
proposals = clip_tiled_boxes(proposals, im_shape)
# remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_shape)
if len(keep) == 0:
proposals = np.zeros((1, 4)).astype('float32')
scores = np.zeros((1, 1)).astype('float32')
return proposals, scores
proposals = proposals[keep, :]
scores = scores[keep, :]
# apply loose nms (e.g. threshold = 0.7)
# take post_nms_topN (e.g. 1000)
# return the top proposals
if nms_thresh > 0:
keep = nms(boxes=proposals,
scores=scores,
nms_threshold=nms_thresh,
eta=eta)
if post_nms_topN > 0 and post_nms_topN < len(keep):
keep = keep[:post_nms_topN]
proposals = proposals[keep, :]
scores = scores[keep, :]
return proposals, scores
def filter_boxes(boxes, min_size, im_shape):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_shape[1])
& (y_ctr < im_shape[0]))[0]
return keep
class TestGenerateProposalsV2Op(OpTest):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
}
def test_check_output(self):
self.check_output()
def setUp(self):
self.op_type = "generate_proposals_v2"
self.set_data()
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 3.0
self.eta = 1.
def init_test_input(self):
batch_size = 1
input_channels = 20
layer_h = 16
layer_w = 16
input_feat = np.random.random(
(batch_size, input_channels, layer_h, layer_w)).astype('float32')
self.anchors, self.variances = anchor_generator_in_python(
input_feat=input_feat,
anchor_sizes=[16., 32.],
aspect_ratios=[0.5, 1.0],
variances=[1.0, 1.0, 1.0, 1.0],
stride=[16.0, 16.0],
offset=0.5)
self.im_shape = np.array([[64, 64]]).astype('float32')
num_anchors = self.anchors.shape[2]
self.scores = np.random.random(
(batch_size, num_anchors, layer_h, layer_w)).astype('float32')
self.bbox_deltas = np.random.random(
(batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_v2_in_python(
self.scores, self.bbox_deltas, self.im_shape, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta)
class TestGenerateProposalsV2OutLodOp(TestGenerateProposalsV2Op):
def set_data(self):
self.init_test_params()
self.init_test_input()
self.init_test_output()
self.inputs = {
'Scores': self.scores,
'BboxDeltas': self.bbox_deltas,
'ImShape': self.im_shape.astype(np.float32),
'Anchors': self.anchors,
'Variances': self.variances
}
self.attrs = {
'pre_nms_topN': self.pre_nms_topN,
'post_nms_topN': self.post_nms_topN,
'nms_thresh': self.nms_thresh,
'min_size': self.min_size,
'eta': self.eta,
'return_rois_num': True
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisNum': (np.asarray(
self.rois_num, dtype=np.int32))
}
class TestGenerateProposalsV2OpNoBoxLeft(TestGenerateProposalsV2Op):
def init_test_params(self):
self.pre_nms_topN = 12000 # train 12000, test 2000
self.post_nms_topN = 5000 # train 6000, test 1000
self.nms_thresh = 0.7
self.min_size = 1000.0
self.eta = 1.
if __name__ == '__main__':
paddle.enable_static()
unittest.main()

@ -673,4 +673,5 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op_xpu', 'test_sgd_op_xpu',
'test_shape_op_xpu', 'test_shape_op_xpu',
'test_slice_op_xpu', 'test_slice_op_xpu',
'test_generate_proposals_v2_op',
] ]

Loading…
Cancel
Save