Enhance generate_proposal_labels_op and fix some bugs. (#13239)

* Enhance generate_proposal_labels_op
* Fix bug in generate_proposals_op
* Refine rpn_target_assign_op.
* by Bu Xingyuan, Wang Guanzhong and Dang Qingqing
Authored by Xingyuan Bu 7 years ago; committed by qingqing01
parent 23ec966cd3
commit 9e2e893f59

@ -305,9 +305,9 @@ paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'neg
paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0))
paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None))
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'anchor_var', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'gt_boxes', 'im_scales', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)

@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
@ -21,7 +22,7 @@ namespace operators {
*/
template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes, const T* weights,
const framework::Tensor& gt_boxes, const float* weights,
const bool normalized, framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
@ -62,5 +63,35 @@ void Gather(const T* in, const int in_stride, const int* index, const int num,
}
}
template <typename T>
void BboxOverlaps(const framework::Tensor& r_boxes,
const framework::Tensor& c_boxes,
framework::Tensor* overlaps) {
auto r_boxes_et = framework::EigenTensor<T, 2>::From(r_boxes);
auto c_boxes_et = framework::EigenTensor<T, 2>::From(c_boxes);
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
auto zero = static_cast<T>(0.0);
T r_box_area, c_box_area, x_min, y_min, x_max, y_max, inter_w, inter_h,
inter_area;
for (int i = 0; i < r_num; ++i) {
r_box_area = (r_boxes_et(i, 2) - r_boxes_et(i, 0) + 1) *
(r_boxes_et(i, 3) - r_boxes_et(i, 1) + 1);
for (int j = 0; j < c_num; ++j) {
c_box_area = (c_boxes_et(j, 2) - c_boxes_et(j, 0) + 1) *
(c_boxes_et(j, 3) - c_boxes_et(j, 1) + 1);
x_min = std::max(r_boxes_et(i, 0), c_boxes_et(j, 0));
y_min = std::max(r_boxes_et(i, 1), c_boxes_et(j, 1));
x_max = std::min(r_boxes_et(i, 2), c_boxes_et(j, 2));
y_max = std::min(r_boxes_et(i, 3), c_boxes_et(j, 3));
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
}
}
}
} // namespace operators
} // namespace paddle
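
For reference, a minimal NumPy sketch of the pairwise overlap the new BboxOverlaps helper computes; boxes are pixel-indexed [xmin, ymin, xmax, ymax], hence the +1 in widths and heights. The function name and shapes here are illustrative only, not part of the operator's API.

    import numpy as np

    def bbox_overlaps(r_boxes, c_boxes):
        """Pairwise IoU between r_boxes [R, 4] and c_boxes [C, 4]."""
        overlaps = np.zeros((r_boxes.shape[0], c_boxes.shape[0]), dtype=np.float32)
        r_area = (r_boxes[:, 2] - r_boxes[:, 0] + 1) * (r_boxes[:, 3] - r_boxes[:, 1] + 1)
        c_area = (c_boxes[:, 2] - c_boxes[:, 0] + 1) * (c_boxes[:, 3] - c_boxes[:, 1] + 1)
        for i in range(r_boxes.shape[0]):
            x_min = np.maximum(r_boxes[i, 0], c_boxes[:, 0])
            y_min = np.maximum(r_boxes[i, 1], c_boxes[:, 1])
            x_max = np.minimum(r_boxes[i, 2], c_boxes[:, 2])
            y_max = np.minimum(r_boxes[i, 3], c_boxes[:, 3])
            inter_w = np.maximum(x_max - x_min + 1, 0)
            inter_h = np.maximum(y_max - y_min + 1, 0)
            inter = inter_w * inter_h
            overlaps[i, :] = inter / (r_area[i] + c_area - inter)
        return overlaps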

@ -89,12 +89,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
}
for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len];
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1];
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2;
T anchor_center_y =
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0;
@ -106,25 +105,31 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) *
bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) *
bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
} else {
bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width;
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height;
bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
}
proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
}
// return proposals;
}
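
As a sanity check on the clamped decoding above, here is a hedged NumPy sketch of the per-anchor math (the variance branch is omitted for brevity): widths and heights use the +1 pixel convention, the width/height delta is capped at log(1000/16) before exponentiation to avoid overflow, and the decoded xmax/ymax subtract 1. Names are illustrative only.

    import numpy as np

    BBOX_XFORM_CLIP = np.log(1000.0 / 16.0)  # cap on dw/dh before exp()

    def decode_one_box(anchor, delta):
        """anchor, delta: arrays of 4 floats; returns [xmin, ymin, xmax, ymax]."""
        aw = anchor[2] - anchor[0] + 1.0
        ah = anchor[3] - anchor[1] + 1.0
        acx = anchor[0] + 0.5 * aw
        acy = anchor[1] + 0.5 * ah
        cx = delta[0] * aw + acx
        cy = delta[1] * ah + acy
        w = np.exp(min(delta[2], BBOX_XFORM_CLIP)) * aw
        h = np.exp(min(delta[3], BBOX_XFORM_CLIP)) * ah
        return np.array([cx - w / 2, cy - h / 2, cx + w / 2 - 1, cy + h / 2 - 1])

    # A zero delta returns the anchor unchanged:
    # decode_one_box(np.array([0., 0., 9., 9.]), np.zeros(4)) -> [0., 0., 9., 9.]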
@ -156,18 +161,23 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2];
T im_scale = im_info_data[2];
keep->Resize({boxes->dims()[0], 1});
min_size = std::max(min_size, 1.0f);
int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] &&
y_ctr <= im_info_data[0]) {
if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i;
}
}
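
The corresponding filtering step can be sketched in NumPy as below: min_size is clipped to at least 1 and compared against the box size mapped back to the original image scale (im_info = [height, width, scale]), while the center check still uses the scaled image. Names are illustrative, not the operator's API.

    import numpy as np

    def filter_boxes(boxes, min_size, im_info):
        """boxes: [N, 4] in the scaled image; im_info = [height, width, scale]."""
        im_scale = im_info[2]
        min_size = max(min_size, 1.0)
        ws = boxes[:, 2] - boxes[:, 0] + 1
        hs = boxes[:, 3] - boxes[:, 1] + 1
        ws_orig = (boxes[:, 2] - boxes[:, 0]) / im_scale + 1  # width at original scale
        hs_orig = (boxes[:, 3] - boxes[:, 1]) / im_scale + 1  # height at original scale
        x_ctr = boxes[:, 0] + ws / 2.
        y_ctr = boxes[:, 1] + hs / 2.
        keep = np.where((ws_orig >= min_size) & (hs_orig >= min_size) &
                        (x_ctr < im_info[1]) & (y_ctr < im_info[0]))[0]
        return keep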
@ -218,8 +228,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);

(File diff suppressed because it is too large.)

@ -55,15 +55,19 @@ for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP)
def rpn_target_assign(loc,
scores,
def rpn_target_assign(bbox_pred,
cls_logits,
anchor_box,
anchor_var,
gt_box,
gt_boxes,
is_crowd,
im_info,
rpn_batch_size_per_im=256,
fg_fraction=0.25,
rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3):
rpn_negative_overlap=0.3,
use_random=True):
"""
** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. **
@ -83,14 +87,13 @@ def rpn_target_assign(loc,
the positive anchors.
Args:
loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the
bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes. N is the batch size,
and each bounding box has four coordinate values and the layout
is [xmin, ymin, xmax, ymax].
scores(Variable): A 3-D Tensor with shape [N, M, C] represents the
predicted confidence predictions. N is the batch size, C is the
class number, M is number of bounding boxes. For each category
there are total M scores which corresponding M bounding boxes.
cls_logits(Variable): A 3-D Tensor with shape [N, M, 1] represents the
predicted confidence predictions. N is the batch size, 1 is the
foreground/background sigmoid score, and M is the number of bounding boxes.
anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box,
@ -99,11 +102,16 @@ def rpn_target_assign(loc,
coordinate of the anchor box.
anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded
variances of anchors.
gt_box (Variable): The ground-truth bounding boxes (bboxes) are a 2D
gt_boxes (Variable): The ground-truth bounding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input.
is_crowd (Variable): A 1-D LoDTensor which indicates whether a ground-truth box is crowd.
im_info (Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size,
3 is the height, width and scale.
rpn_batch_size_per_im(int): Total number of RPN examples per image.
fg_fraction(float): Target fraction of RoI minibatch that is labeled
rpn_straddle_thresh(float): Remove RPN anchors that go outside the image
by straddle_thresh pixels.
rpn_fg_fraction(float): Target fraction of RoI minibatch that is labeled
foreground (i.e. class > 0), 0-th class is background.
rpn_positive_overlap(float): Minimum overlap required between an anchor
and ground-truth box for the (anchor, gt box) pair to be a positive
@ -129,45 +137,48 @@ def rpn_target_assign(loc,
Examples:
.. code-block:: python
loc = layers.data(name='location', shape=[2, 80],
bbox_pred = layers.data(name='bbox_pred', shape=[100, 4],
append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 40],
cls_logits = layers.data(name='cls_logits', shape=[100, 1],
append_batch_size=False, dtype='float32')
anchor_box = layers.data(name='anchor_box', shape=[20, 4],
append_batch_size=False, dtype='float32')
gt_box = layers.data(name='gt_box', shape=[10, 4],
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target =
fluid.layers.detection_output(loc=location,
scores=scores,
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
gt_box=gt_box)
gt_boxes=gt_boxes)
"""
helper = LayerHelper('rpn_target_assign', **locals())
# Compute overlaps between the prior boxes and the gt boxes overlaps
iou = iou_similarity(x=gt_box, y=anchor_box)
# Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int64')
target_label = helper.create_tmp_variable(dtype='int32')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
helper.append_op(
type="rpn_target_assign",
inputs={'Anchor': anchor_box,
'GtBox': gt_box,
'DistMat': iou},
inputs={
'Anchor': anchor_box,
'GtBoxes': gt_boxes,
'IsCrowd': is_crowd,
'ImInfo': im_info
},
outputs={
'LocationIndex': loc_index,
'ScoreIndex': score_index,
'TargetLabel': target_label,
'TargetBBox': target_bbox,
'TargetBBox': target_bbox
},
attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im,
'rpn_straddle_thresh': rpn_straddle_thresh,
'rpn_positive_overlap': rpn_positive_overlap,
'rpn_negative_overlap': rpn_negative_overlap,
'fg_fraction': fg_fraction
'rpn_fg_fraction': rpn_fg_fraction,
'use_random': use_random
})
loc_index.stop_gradient = True
@ -175,12 +186,12 @@ def rpn_target_assign(loc,
target_label.stop_gradient = True
target_bbox.stop_gradient = True
scores = nn.reshape(x=scores, shape=(-1, 1))
loc = nn.reshape(x=loc, shape=(-1, 4))
predicted_scores = nn.gather(scores, score_index)
predicted_location = nn.gather(loc, loc_index)
cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_scores, predicted_location, target_label, target_bbox
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
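
Taken together, the new interface is exercised like the updated unit test later in this diff. The following is a hedged, self-contained sketch (shapes follow that test and are illustrative only): is_crowd and im_info are now required inputs, and fg_fraction was renamed to rpn_fg_fraction alongside the new rpn_straddle_thresh and use_random attributes.

    import paddle.fluid.layers as layers
    from paddle.fluid.framework import Program, program_guard

    prog = Program()
    with program_guard(prog):
        bbox_pred = layers.data(name='bbox_pred', shape=[10, 50, 4],
                                append_batch_size=False, dtype='float32')
        cls_logits = layers.data(name='cls_logits', shape=[10, 50, 2],
                                 append_batch_size=False, dtype='float32')
        anchor_box = layers.data(name='anchor_box', shape=[50, 4],
                                 append_batch_size=False, dtype='float32')
        anchor_var = layers.data(name='anchor_var', shape=[50, 4],
                                 append_batch_size=False, dtype='float32')
        gt_boxes = layers.data(name='gt_boxes', shape=[4], lod_level=1,
                               dtype='float32')
        is_crowd = layers.data(name='is_crowd', shape=[10], dtype='int32',
                               lod_level=1, append_batch_size=False)
        im_info = layers.data(name='im_info', shape=[1, 3], dtype='float32',
                              lod_level=1, append_batch_size=False)
        pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
            bbox_pred=bbox_pred, cls_logits=cls_logits,
            anchor_box=anchor_box, anchor_var=anchor_var,
            gt_boxes=gt_boxes, is_crowd=is_crowd, im_info=im_info,
            rpn_batch_size_per_im=256, rpn_straddle_thresh=0.0,
            rpn_fg_fraction=0.5, rpn_positive_overlap=0.7,
            rpn_negative_overlap=0.3)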
def detection_output(loc,
@ -1258,15 +1269,17 @@ def anchor_generator(input,
def generate_proposal_labels(rpn_rois,
gt_classes,
is_crowd,
gt_boxes,
im_scales,
im_info,
batch_size_per_im=256,
fg_fraction=0.25,
fg_thresh=0.25,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=None):
class_nums=None,
use_random=True):
"""
** Generate proposal labels for Faster-RCNN **
TODO(buxingyuan): Add Document
@ -1285,8 +1298,9 @@ def generate_proposal_labels(rpn_rois,
inputs={
'RpnRois': rpn_rois,
'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes,
'ImScales': im_scales
'ImInfo': im_info
},
outputs={
'Rois': rois,
@ -1302,7 +1316,8 @@ def generate_proposal_labels(rpn_rois,
'bg_thresh_hi': bg_thresh_hi,
'bg_thresh_lo': bg_thresh_lo,
'bbox_reg_weights': bbox_reg_weights,
'class_nums': class_nums
'class_nums': class_nums,
'use_random': use_random
})
rois.stop_gradient = True
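
Similarly, a hedged sketch of calling the updated generate_proposal_labels (mirroring the unit test later in this diff; shapes are illustrative), where the is_crowd and im_info inputs replace the old im_scales:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers
    from paddle.fluid.framework import Program, program_guard

    prog = Program()
    with program_guard(prog):
        rpn_rois = layers.data(name='rpn_rois', shape=[4, 4], dtype='float32',
                               lod_level=1, append_batch_size=False)
        gt_classes = layers.data(name='gt_classes', shape=[6], dtype='int32',
                                 lod_level=1, append_batch_size=False)
        is_crowd = layers.data(name='is_crowd', shape=[6], dtype='int32',
                               lod_level=1, append_batch_size=False)
        gt_boxes = layers.data(name='gt_boxes', shape=[6, 4], dtype='float32',
                               lod_level=1, append_batch_size=False)
        im_info = layers.data(name='im_info', shape=[1, 3], dtype='float32',
                              lod_level=1, append_batch_size=False)
        rois, labels_int32, bbox_targets, bbox_inside_w, bbox_outside_w = \
            fluid.layers.generate_proposal_labels(
                rpn_rois=rpn_rois, gt_classes=gt_classes, is_crowd=is_crowd,
                gt_boxes=gt_boxes, im_info=im_info,
                batch_size_per_im=256, fg_fraction=0.25, fg_thresh=0.5,
                bg_thresh_hi=0.5, bg_thresh_lo=0.0,
                bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], class_nums=5,
                use_random=True)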

@ -148,51 +148,60 @@ class TestAnchorGenerator(unittest.TestCase):
class TestGenerateProposalLabels(unittest.TestCase):
def test_generate_proposal_labels(self):
rpn_rois = layers.data(
name='rpn_rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_boxes = layers.data(
name='gt_boxes',
shape=[6, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
im_scales = layers.data(
name='im_scales',
shape=[1],
dtype='float32',
lod_level=1,
append_batch_size=False)
class_nums = 5
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
gt_boxes=gt_boxes,
im_scales=im_scales,
batch_size_per_im=2,
fg_fraction=0.5,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
program = Program()
with program_guard(program):
rpn_rois = layers.data(
name='rpn_rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
is_crowd = layers.data(
name='is_crowd',
shape=[6],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_boxes = layers.data(
name='gt_boxes',
shape=[6, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
class_nums = 5
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_boxes=gt_boxes,
im_info=im_info,
batch_size_per_im=2,
fg_fraction=0.5,
fg_thresh=0.5,
bg_thresh_hi=0.5,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestMultiBoxHead(unittest.TestCase):
@ -254,18 +263,18 @@ class TestRpnTargetAssign(unittest.TestCase):
def test_rpn_target_assign(self):
program = Program()
with program_guard(program):
loc_shape = [10, 50, 4]
score_shape = [10, 50, 2]
bbox_pred_shape = [10, 50, 4]
cls_logits_shape = [10, 50, 2]
anchor_shape = [50, 4]
loc = layers.data(
name='loc',
shape=loc_shape,
bbox_pred = layers.data(
name='bbox_pred',
shape=bbox_pred_shape,
append_batch_size=False,
dtype='float32')
scores = layers.data(
name='scores',
shape=score_shape,
cls_logits = layers.data(
name='cls_logits',
shape=cls_logits_shape,
append_batch_size=False,
dtype='float32')
anchor_box = layers.data(
@ -278,17 +287,31 @@ class TestRpnTargetAssign(unittest.TestCase):
shape=anchor_shape,
append_batch_size=False,
dtype='float32')
gt_box = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32')
gt_boxes = layers.data(
name='gt_boxes', shape=[4], lod_level=1, dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[10],
dtype='int32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
loc=loc,
scores=scores,
bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
anchor_var=anchor_var,
gt_box=gt_box,
gt_boxes=gt_boxes,
is_crowd=is_crowd,
im_info=im_info,
rpn_batch_size_per_im=256,
fg_fraction=0.25,
rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3)

@ -20,10 +20,10 @@ import paddle.fluid as fluid
from op_test import OpTest
def generate_proposal_labels_in_python(
rpn_rois, gt_classes, gt_boxes, im_scales, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums):
def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
im_info, batch_size_per_im, fg_fraction,
fg_thresh, bg_thresh_hi, bg_thresh_lo,
bbox_reg_weights, class_nums):
rois = []
labels_int32 = []
bbox_targets = []
@ -31,13 +31,13 @@ def generate_proposal_labels_in_python(
bbox_outside_weights = []
lod = []
assert len(rpn_rois) == len(
im_scales), 'batch size of rpn_rois and ground_truth is not matched'
im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_scales)):
for im_i in range(len(im_info)):
frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], gt_boxes[im_i], im_scales[im_i],
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums)
rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
lod.append(frcn_blobs['rois'].shape[0])
@ -50,13 +50,14 @@ def generate_proposal_labels_in_python(
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod
def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo,
bbox_reg_weights, class_nums):
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums):
rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale
@ -78,6 +79,9 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1
max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1)
@ -85,9 +89,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image:
fg_inds = np.random.choice(
fg_inds, size=fg_rois_per_this_image, replace=False)
# if fg_inds.shape[0] > fg_rois_per_this_image:
# fg_inds = np.random.choice(
# fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
@ -96,9 +101,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0])
# Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image:
bg_inds = np.random.choice(
bg_inds, size=bg_rois_per_this_image, replace=False)
# if bg_inds.shape[0] > bg_rois_per_this_image:
# bg_inds = np.random.choice(
# bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds]
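
To keep the expected outputs of the unit test reproducible, the random sub-sampling above is commented out and the op attribute use_random is set to False, so the first qualifying indices are taken instead; crowd ground truths are excluded earlier by forcing their overlap rows to -1. A compact, hedged NumPy sketch of that deterministic selection (function name illustrative):

    import numpy as np

    def sample_fg_bg(max_overlaps, fg_thresh, bg_thresh_hi, bg_thresh_lo,
                     rois_per_im, fg_fraction):
        """Deterministic fg/bg ROI selection (the use_random=False path)."""
        fg_rois_per_im = int(np.round(fg_fraction * rois_per_im))
        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
        fg_inds = fg_inds[:min(fg_rois_per_im, fg_inds.shape[0])]
        bg_inds = np.where((max_overlaps < bg_thresh_hi) &
                           (max_overlaps >= bg_thresh_lo))[0]
        bg_inds = bg_inds[:min(rois_per_im - fg_inds.shape[0], bg_inds.shape[0])]
        return np.append(fg_inds, bg_inds)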
@ -208,8 +214,9 @@ class TestGenerateProposalLabelsOp(OpTest):
self.inputs = {
'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod),
'GtClasses': (self.gt_classes[0], self.gts_lod),
'IsCrowd': (self.is_crowd[0], self.gts_lod),
'GtBoxes': (self.gt_boxes[0], self.gts_lod),
'ImScales': self.im_scales[0]
'ImInfo': self.im_info
}
self.attrs = {
'batch_size_per_im': self.batch_size_per_im,
@ -218,14 +225,15 @@ class TestGenerateProposalLabelsOp(OpTest):
'bg_thresh_hi': self.bg_thresh_hi,
'bg_thresh_lo': self.bg_thresh_lo,
'bbox_reg_weights': self.bbox_reg_weights,
'class_nums': self.class_nums
'class_nums': self.class_nums,
'use_random': False
}
self.outputs = {
'Rois': (self.rois[0], [self.lod]),
'LabelsInt32': (self.labels_int32[0], [self.lod]),
'BboxTargets': (self.bbox_targets[0], [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights[0], [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights[0], [self.lod]),
'Rois': (self.rois, [self.lod]),
'LabelsInt32': (self.labels_int32, [self.lod]),
'BboxTargets': (self.bbox_targets, [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights, [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights, [self.lod]),
}
def test_check_output(self):
@ -236,8 +244,8 @@ class TestGenerateProposalLabelsOp(OpTest):
self.set_data()
def init_test_params(self):
self.batch_size_per_im = 10
self.fg_fraction = 1.0
self.batch_size_per_im = 512
self.fg_fraction = 0.25
self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0
@ -246,14 +254,14 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_input(self):
np.random.seed(0)
image_nums = 1
gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = self.batch_size_per_im - gt_nums
images_shape = []
self.im_scales = []
for i in range(image_nums):
images_shape.append(np.random.randint(200, size=2))
self.im_scales.append(np.ones((1)).astype(np.float32))
proposal_nums = 2000 #self.batch_size_per_im - gt_nums
images_shape = [[64, 64]]
self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(len(images_shape)):
self.im_info[i, 0] = images_shape[i][0]
self.im_info[i, 1] = images_shape[i][1]
self.im_info[i, 2] = 0.8 #scale
self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape,
proposal_nums)
@ -261,16 +269,23 @@ class TestGenerateProposalLabelsOp(OpTest):
images_shape, self.class_nums, gt_nums)
self.gt_classes = [gt['gt_classes'] for gt in ground_truth]
self.gt_boxes = [gt['boxes'] for gt in ground_truth]
self.is_crowd = [gt['is_crowd'] for gt in ground_truth]
def init_test_output(self):
self.rois, self.labels_int32, self.bbox_targets, \
self.bbox_inside_weights, self.bbox_outside_weights, \
self.lod = generate_proposal_labels_in_python(
self.rpn_rois, self.gt_classes, self.gt_boxes, self.im_scales,
self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums
)
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)
self.labels_int32 = self.labels_int32[:, np.newaxis]
self.bbox_targets = np.vstack(self.bbox_targets)
self.bbox_inside_weights = np.vstack(self.bbox_inside_weights)
self.bbox_outside_weights = np.vstack(self.bbox_outside_weights)
def _generate_proposals(images_shape, proposal_nums):
@ -280,7 +295,7 @@ def _generate_proposals(images_shape, proposal_nums):
for i, image_shape in enumerate(images_shape):
proposals = _generate_boxes(image_shape, proposal_nums)
rpn_rois.append(proposals)
num_proposals += len(proposals)
num_proposals = len(proposals)
rpn_rois_lod.append(num_proposals)
return rpn_rois, [rpn_rois_lod]
@ -294,7 +309,11 @@ def _generate_groundtruth(images_shape, class_nums, gt_nums):
gt_classes = np.random.randint(
low=1, high=class_nums, size=gt_nums).astype(np.int32)
gt_boxes = _generate_boxes(image_shape, gt_nums)
ground_truth.append(dict(gt_classes=gt_classes, boxes=gt_boxes))
is_crowd = np.zeros((gt_nums), dtype=np.int32)
is_crowd[0] = 1
ground_truth.append(
dict(
gt_classes=gt_classes, boxes=gt_boxes, is_crowd=is_crowd))
num_gts += len(gt_classes)
gts_lod.append(num_gts)
return ground_truth, [gts_lod]

@ -114,10 +114,10 @@ def box_coder(all_anchors, bbox_deltas, variances):
#anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0]
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1]
anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2
anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + 1
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + 1
anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0]
anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1]
#predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
@ -127,23 +127,29 @@ def box_coder(all_anchors, bbox_deltas, variances):
i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(variances[i, 2] *
bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(variances[i, 3] *
bbox_deltas[i, 3]) * anchor_loc[i, 1]
pred_bbox[i, 2] = math.exp(
min(variances[i, 2] * bbox_deltas[i, 2], math.log(
1000 / 16.0))) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(
min(variances[i, 3] * bbox_deltas[i, 3], math.log(
1000 / 16.0))) * anchor_loc[i, 1]
else:
for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3]
pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0]
pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1]
pred_bbox[i, 2] = math.exp(
min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i,
0]
pred_bbox[i, 3] = math.exp(
min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i,
1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - 1
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - 1
return proposals
@ -170,13 +176,16 @@ def filter_boxes(boxes, min_size, im_info):
"""Only keep boxes with both sides >= min_size and center within the image.
"""
# Scale min_size to match image scale
min_size *= im_info[2]
im_scale = im_info[2]
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
ws_orig_scale = (boxes[:, 2] - boxes[:, 0]) / im_scale + 1
hs_orig_scale = (boxes[:, 3] - boxes[:, 1]) / im_scale + 1
x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) &
(y_ctr < im_info[0]))[0]
keep = np.where((ws_orig_scale >= min_size) & (hs_orig_scale >= min_size) &
(x_ctr < im_info[1]) & (y_ctr < im_info[0]))[0]
return keep
@ -204,7 +213,7 @@ def iou(box_a, box_b):
xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)
inter_area = max(xb - xa + 1, 0.0) * max(yb - ya + 1, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area)