add generate_proposals_v2 op (#28214)
* add generate_proposals_v2 op
parent b96869bc31
commit 5262b02585
4 file diffs suppressed because they are too large.
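Only two of the changed files are shown below: the new CUDA kernel and its Python unit test. For orientation, here is a rough, hypothetical sketch (not taken from this commit) of how a static-graph Python wrapper would typically wire this op through fluid's generic LayerHelper mechanism, using the input, output, and attribute names the kernel reads; the wrapper actually added elsewhere in this PR (not shown here) may differ in name and signature.

from paddle.fluid.layer_helper import LayerHelper


def generate_proposals_v2_sketch(scores, bbox_deltas, im_shape, anchors,
                                 variances, pre_nms_top_n=6000,
                                 post_nms_top_n=1000, nms_thresh=0.7,
                                 min_size=0.1, eta=1.0):
    # Hypothetical wrapper, for illustration only.
    helper = LayerHelper('generate_proposals_v2', **locals())
    rpn_rois = helper.create_variable_for_type_inference(
        dtype=bbox_deltas.dtype)
    rpn_roi_probs = helper.create_variable_for_type_inference(
        dtype=scores.dtype)
    helper.append_op(
        type='generate_proposals_v2',
        inputs={
            'Scores': scores,
            'BboxDeltas': bbox_deltas,
            'ImShape': im_shape,
            'Anchors': anchors,
            'Variances': variances
        },
        outputs={'RpnRois': rpn_rois,
                 'RpnRoiProbs': rpn_roi_probs},
        attrs={
            'pre_nms_topN': pre_nms_top_n,
            'post_nms_topN': post_nms_top_n,
            'nms_thresh': nms_thresh,
            'min_size': min_size,
            'eta': eta
        })
    return rpn_rois, rpn_roi_probs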
@@ -0,0 +1,229 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <paddle/fluid/memory/allocation/allocator.h>
#include <stdio.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/bbox_util.cu.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

namespace {
template <typename T>
static std::pair<Tensor, Tensor> ProposalForOneImage(
    const platform::CUDADeviceContext &ctx, const Tensor &im_shape,
    const Tensor &anchors, const Tensor &variances,
    const Tensor &bbox_deltas,  // [M, 4]
    const Tensor &scores,       // [N, 1]
    int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
    float eta) {
  // 1. pre nms
  Tensor scores_sort, index_sort;
  SortDescending<T>(ctx, scores, &scores_sort, &index_sort);
  int num = scores.numel();
  int pre_nms_num = (pre_nms_top_n <= 0 || pre_nms_top_n > num) ? scores.numel()
                                                                : pre_nms_top_n;
  scores_sort.Resize({pre_nms_num, 1});
  index_sort.Resize({pre_nms_num, 1});

  // 2. box decode and clipping
  Tensor proposals;
  proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());

  {
    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
    for_range(BoxDecodeAndClipFunctor<T>{
        anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
        index_sort.data<int>(), im_shape.data<T>(), proposals.data<T>()});
  }

  // 3. filter
  Tensor keep_index, keep_num_t;
  keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
  keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
  min_size = std::max(min_size, 1.0f);
  auto stream = ctx.stream();
  FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
      proposals.data<T>(), im_shape.data<T>(), min_size, pre_nms_num,
      keep_num_t.data<int>(), keep_index.data<int>(), false);
  int keep_num;
  const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
  memory::Copy(platform::CPUPlace(), &keep_num, gpu_place,
               keep_num_t.data<int>(), sizeof(int), ctx.stream());
  ctx.Wait();
  keep_index.Resize({keep_num});

  Tensor scores_filter, proposals_filter;
  // Handle the case when there is no keep index left
  if (keep_num == 0) {
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    proposals_filter.mutable_data<T>({1, 4}, ctx.GetPlace());
    scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
    set_zero(ctx, &proposals_filter, static_cast<T>(0));
    set_zero(ctx, &scores_filter, static_cast<T>(0));
    return std::make_pair(proposals_filter, scores_filter);
  }
  proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
  scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
  GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
  GPUGather<T>(ctx, scores_sort, keep_index, &scores_filter);

  if (nms_thresh <= 0) {
    return std::make_pair(proposals_filter, scores_filter);
  }

  // 4. nms
  Tensor keep_nms;
  NMS<T>(ctx, proposals_filter, keep_index, nms_thresh, &keep_nms);
  if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
    keep_nms.Resize({post_nms_top_n});
  }

  Tensor scores_nms, proposals_nms;
  proposals_nms.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
  scores_nms.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
  GPUGather<T>(ctx, proposals_filter, keep_nms, &proposals_nms);
  GPUGather<T>(ctx, scores_filter, keep_nms, &scores_nms);

  return std::make_pair(proposals_nms, scores_nms);
}
} // namespace

template <typename DeviceContext, typename T>
class CUDAGenerateProposalsV2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *scores = context.Input<Tensor>("Scores");
    auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
    auto *im_shape = context.Input<Tensor>("ImShape");
    auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
                                   "Anchors", "GenerateProposals");
    auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
                                     "Input", "Variances", "GenerateProposals");

    auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
    auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

    int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
    int post_nms_top_n = context.Attr<int>("post_nms_topN");
    float nms_thresh = context.Attr<float>("nms_thresh");
    float min_size = context.Attr<float>("min_size");
    float eta = context.Attr<float>("eta");
    PADDLE_ENFORCE_GE(eta, 1.,
                      platform::errors::InvalidArgument(
                          "Adaptive NMS is not supported. The attribute 'eta' "
                          "should not be less than 1. But received eta=[%f]",
                          eta));

    auto &dev_ctx = context.template device_context<DeviceContext>();

    auto scores_dim = scores->dims();
    int64_t num = scores_dim[0];
    int64_t c_score = scores_dim[1];
    int64_t h_score = scores_dim[2];
    int64_t w_score = scores_dim[3];

    auto bbox_dim = bbox_deltas->dims();
    int64_t c_bbox = bbox_dim[1];
    int64_t h_bbox = bbox_dim[2];
    int64_t w_bbox = bbox_dim[3];

    Tensor bbox_deltas_swap, scores_swap;
    bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
                                     dev_ctx.GetPlace());
    scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
                                dev_ctx.GetPlace());

    math::Transpose<DeviceContext, T, 4> trans;
    std::vector<int> axis = {0, 2, 3, 1};
    trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
    trans(dev_ctx, *scores, &scores_swap, axis);

    anchors.Resize({anchors.numel() / 4, 4});
    variances.Resize({variances.numel() / 4, 4});

    rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
                              context.GetPlace());
    rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());

    T *rpn_rois_data = rpn_rois->data<T>();
    T *rpn_roi_probs_data = rpn_roi_probs->data<T>();

    auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
    auto cpu_place = platform::CPUPlace();

    int64_t num_proposals = 0;
    std::vector<size_t> offset(1, 0);
    std::vector<int> tmp_num;

    for (int64_t i = 0; i < num; ++i) {
      Tensor im_shape_slice = im_shape->Slice(i, i + 1);
      Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
      Tensor scores_slice = scores_swap.Slice(i, i + 1);

      bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
      scores_slice.Resize({h_score * w_score * c_score, 1});

      std::pair<Tensor, Tensor> box_score_pair =
          ProposalForOneImage<T>(dev_ctx, im_shape_slice, anchors, variances,
                                 bbox_deltas_slice, scores_slice, pre_nms_top_n,
                                 post_nms_top_n, nms_thresh, min_size, eta);

      Tensor &proposals = box_score_pair.first;
      Tensor &scores = box_score_pair.second;

      memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
                   proposals.data<T>(), sizeof(T) * proposals.numel(),
                   dev_ctx.stream());
      memory::Copy(place, rpn_roi_probs_data + num_proposals, place,
                   scores.data<T>(), sizeof(T) * scores.numel(),
                   dev_ctx.stream());
      dev_ctx.Wait();
      num_proposals += proposals.dims()[0];
      offset.emplace_back(num_proposals);
      tmp_num.push_back(proposals.dims()[0]);
    }
    if (context.HasOutput("RpnRoisNum")) {
      auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
      rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
      int *num_data = rpn_rois_num->data<int>();
      memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num,
                   dev_ctx.stream());
      rpn_rois_num->Resize({num});
    }
    framework::LoD lod;
    lod.emplace_back(offset);
    rpn_rois->set_lod(lod);
    rpn_roi_probs->set_lod(lod);
    rpn_rois->Resize({num_proposals, 4});
    rpn_roi_probs->Resize({num_proposals, 1});
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(generate_proposals_v2,
                        ops::CUDAGenerateProposalsV2Kernel<
                            paddle::platform::CUDADeviceContext, float>);
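Note on the outputs of the kernel above: RpnRois and RpnRoiProbs are concatenated across images, and the per-image proposal counts are recorded both as level-0 LoD offsets and, when RpnRoisNum is requested, as a plain int tensor. A small plain-Python sketch of that bookkeeping (illustrative values only, not part of the diff):

# Per-image proposal counts as produced by ProposalForOneImage;
# the values here are made up for illustration.
per_image_counts = [3, 1, 5]

offset = [0]
for n in per_image_counts:
    offset.append(offset[-1] + n)

# offset becomes the level-0 LoD of RpnRois / RpnRoiProbs,
# and per_image_counts is what RpnRoisNum would hold.
assert offset == [0, 3, 4, 9]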
@@ -0,0 +1,238 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys
import math
import paddle
import paddle.fluid as fluid
from op_test import OpTest
from test_multiclass_nms_op import nms
from test_anchor_generator_op import anchor_generator_in_python
import copy
from test_generate_proposals_op import clip_tiled_boxes, box_coder, nms


def generate_proposals_v2_in_python(scores, bbox_deltas, im_shape, anchors,
                                    variances, pre_nms_topN, post_nms_topN,
                                    nms_thresh, min_size, eta):
    all_anchors = anchors.reshape(-1, 4)
    rois = np.empty((0, 5), dtype=np.float32)
    roi_probs = np.empty((0, 1), dtype=np.float32)

    rpn_rois = []
    rpn_roi_probs = []
    rois_num = []
    num_images = scores.shape[0]
    for img_idx in range(num_images):
        img_i_boxes, img_i_probs = proposal_for_one_image(
            im_shape[img_idx, :], all_anchors, variances,
            bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
            pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
        rois_num.append(img_i_probs.shape[0])
        rpn_rois.append(img_i_boxes)
        rpn_roi_probs.append(img_i_probs)

    return rpn_rois, rpn_roi_probs, rois_num


def proposal_for_one_image(im_shape, all_anchors, variances, bbox_deltas,
                           scores, pre_nms_topN, post_nms_topN, nms_thresh,
                           min_size, eta):
    # Transpose and reshape predicted bbox transformations to get them
    # into the same order as the anchors:
    # - bbox deltas will be (4 * A, H, W) format from conv output
    # - transpose to (H, W, 4 * A)
    # - reshape to (H * W * A, 4) where rows are ordered by (H, W, A)
    #   in slowest to fastest order to match the enumerated anchors
    bbox_deltas = bbox_deltas.transpose((1, 2, 0)).reshape(-1, 4)
    all_anchors = all_anchors.reshape(-1, 4)
    variances = variances.reshape(-1, 4)
    # Same story for the scores:
    # - scores are (A, H, W) format from conv output
    # - transpose to (H, W, A)
    # - reshape to (H * W * A, 1) where rows are ordered by (H, W, A)
    #   to match the order of anchors and bbox_deltas
    scores = scores.transpose((1, 2, 0)).reshape(-1, 1)

    # sort all (proposal, score) pairs by score from highest to lowest
    # take top pre_nms_topN (e.g. 6000)
    if pre_nms_topN <= 0 or pre_nms_topN >= len(scores):
        order = np.argsort(-scores.squeeze())
    else:
        # Avoid sorting possibly large arrays;
        # first partition to get the top K unsorted and then sort just those
        # (a standalone check of this trick follows after this function).
        inds = np.argpartition(-scores.squeeze(), pre_nms_topN)[:pre_nms_topN]
        order = np.argsort(-scores[inds].squeeze())
        order = inds[order]
    scores = scores[order, :]
    bbox_deltas = bbox_deltas[order, :]
    all_anchors = all_anchors[order, :]
    proposals = box_coder(all_anchors, bbox_deltas, variances)
    # clip proposals to image (may result in proposals with zero area
    # that will be removed in the next step)
    proposals = clip_tiled_boxes(proposals, im_shape)
    # remove predicted boxes with height or width < min_size
    keep = filter_boxes(proposals, min_size, im_shape)
    if len(keep) == 0:
        proposals = np.zeros((1, 4)).astype('float32')
        scores = np.zeros((1, 1)).astype('float32')
        return proposals, scores
    proposals = proposals[keep, :]
    scores = scores[keep, :]

    # apply loose nms (e.g. threshold = 0.7)
    # take post_nms_topN (e.g. 1000)
    # return the top proposals
    if nms_thresh > 0:
        keep = nms(boxes=proposals,
                   scores=scores,
                   nms_threshold=nms_thresh,
                   eta=eta)
        if post_nms_topN > 0 and post_nms_topN < len(keep):
            keep = keep[:post_nms_topN]
        proposals = proposals[keep, :]
        scores = scores[keep, :]

    return proposals, scores


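# Illustrative check (not part of the original test): the pre-NMS top-K
# selection above relies on np.argpartition returning the K best indices
# unsorted, then sorting only those K; the result matches a full descending
# argsort truncated to K. The values below are arbitrary.
_demo_scores = np.array([0.1, 0.9, 0.4, 0.7, 0.3])
_demo_topk = 3
_demo_inds = np.argpartition(-_demo_scores, _demo_topk)[:_demo_topk]
_demo_order = _demo_inds[np.argsort(-_demo_scores[_demo_inds])]
assert list(_demo_order) == list(np.argsort(-_demo_scores)[:_demo_topk])

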
def filter_boxes(boxes, min_size, im_shape):
    """Only keep boxes with both sides >= min_size and center within the image.
    """
    # Scale min_size to match image scale
    min_size = max(min_size, 1.0)
    ws = boxes[:, 2] - boxes[:, 0] + 1
    hs = boxes[:, 3] - boxes[:, 1] + 1
    x_ctr = boxes[:, 0] + ws / 2.
    y_ctr = boxes[:, 1] + hs / 2.
    keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_shape[1])
                    & (y_ctr < im_shape[0]))[0]
    return keep


class TestGenerateProposalsV2Op(OpTest):
    def set_data(self):
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        self.inputs = {
            'Scores': self.scores,
            'BboxDeltas': self.bbox_deltas,
            'ImShape': self.im_shape.astype(np.float32),
            'Anchors': self.anchors,
            'Variances': self.variances
        }

        self.attrs = {
            'pre_nms_topN': self.pre_nms_topN,
            'post_nms_topN': self.post_nms_topN,
            'nms_thresh': self.nms_thresh,
            'min_size': self.min_size,
            'eta': self.eta
        }

        self.outputs = {
            'RpnRois': (self.rpn_rois[0], [self.rois_num]),
            'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
        }

    def test_check_output(self):
        self.check_output()

    def setUp(self):
        self.op_type = "generate_proposals_v2"
        self.set_data()

    def init_test_params(self):
        self.pre_nms_topN = 12000  # train 12000, test 2000
        self.post_nms_topN = 5000  # train 6000, test 1000
        self.nms_thresh = 0.7
        self.min_size = 3.0
        self.eta = 1.

    def init_test_input(self):
        batch_size = 1
        input_channels = 20
        layer_h = 16
        layer_w = 16
        input_feat = np.random.random(
            (batch_size, input_channels, layer_h, layer_w)).astype('float32')
        self.anchors, self.variances = anchor_generator_in_python(
            input_feat=input_feat,
            anchor_sizes=[16., 32.],
            aspect_ratios=[0.5, 1.0],
            variances=[1.0, 1.0, 1.0, 1.0],
            stride=[16.0, 16.0],
            offset=0.5)
        self.im_shape = np.array([[64, 64]]).astype('float32')
        num_anchors = self.anchors.shape[2]
        self.scores = np.random.random(
            (batch_size, num_anchors, layer_h, layer_w)).astype('float32')
        self.bbox_deltas = np.random.random(
            (batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')

    def init_test_output(self):
        self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_v2_in_python(
            self.scores, self.bbox_deltas, self.im_shape, self.anchors,
            self.variances, self.pre_nms_topN, self.post_nms_topN,
            self.nms_thresh, self.min_size, self.eta)


class TestGenerateProposalsV2OutLodOp(TestGenerateProposalsV2Op):
    def set_data(self):
        self.init_test_params()
        self.init_test_input()
        self.init_test_output()
        self.inputs = {
            'Scores': self.scores,
            'BboxDeltas': self.bbox_deltas,
            'ImShape': self.im_shape.astype(np.float32),
            'Anchors': self.anchors,
            'Variances': self.variances
        }

        self.attrs = {
            'pre_nms_topN': self.pre_nms_topN,
            'post_nms_topN': self.post_nms_topN,
            'nms_thresh': self.nms_thresh,
            'min_size': self.min_size,
            'eta': self.eta,
            'return_rois_num': True
        }

        self.outputs = {
            'RpnRois': (self.rpn_rois[0], [self.rois_num]),
            'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
            'RpnRoisNum': (np.asarray(
                self.rois_num, dtype=np.int32))
        }


class TestGenerateProposalsV2OpNoBoxLeft(TestGenerateProposalsV2Op):
    def init_test_params(self):
        self.pre_nms_topN = 12000  # train 12000, test 2000
        self.post_nms_topN = 5000  # train 6000, test 1000
        self.nms_thresh = 0.7
        self.min_size = 1000.0
        self.eta = 1.


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
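To see the Python reference pipeline above in action outside of OpTest, it can be driven directly with randomly generated inputs shaped like those in init_test_input. A minimal sketch, assuming the test module is importable as test_generate_proposals_v2_op (the file name is not shown on this page):

import numpy as np
# anchor_generator_in_python is the same helper the unit test above imports.
from test_anchor_generator_op import anchor_generator_in_python
from test_generate_proposals_v2_op import generate_proposals_v2_in_python

np.random.seed(0)
feat = np.random.random((1, 20, 16, 16)).astype('float32')
anchors, variances = anchor_generator_in_python(
    input_feat=feat,
    anchor_sizes=[16., 32.],
    aspect_ratios=[0.5, 1.0],
    variances=[1.0, 1.0, 1.0, 1.0],
    stride=[16.0, 16.0],
    offset=0.5)
num_anchors = anchors.shape[2]
scores = np.random.random((1, num_anchors, 16, 16)).astype('float32')
bbox_deltas = np.random.random((1, num_anchors * 4, 16, 16)).astype('float32')
im_shape = np.array([[64, 64]]).astype('float32')

rois, probs, rois_num = generate_proposals_v2_in_python(
    scores, bbox_deltas, im_shape, anchors, variances,
    pre_nms_topN=6000, post_nms_topN=1000,
    nms_thresh=0.7, min_size=3.0, eta=1.0)
print(rois[0].shape, probs[0].shape, rois_num)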