Update generate_proposal_labels_op to support CascadeRCNN. (#17200)

* Update generate_proposal_labels_op to support CascadeRCNN.
revert-18229-add_multi_gpu_install_check
FDInSky 6 years ago committed by qingqing01
parent 9ed2f936f1
commit 9e4b9d9798

@ -351,7 +351,7 @@ paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits',
paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)), ('document', '9c601df88b251f22e9311c52939948cd'))
paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'c0d00acf724691ff3480d4207036a722'))
paddle.fluid.layers.generate_proposals (ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)), ('document', 'b7d707822b6af2a586bce608040235b1'))
paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None), ('document', 'b319b10ddaf17fb4ddf03518685a17ef'))
paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '72fca4a39ccf82d5c746ae62d1868a99'))

@ -2075,9 +2075,13 @@ def generate_proposal_labels(rpn_rois,
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=None,
use_random=True):
use_random=True,
is_cls_agnostic=False,
is_cascade_rcnn=False):
"""
** Generate Proposal Labels of Faster-RCNN **
This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth,
to sample foreground boxes and background boxes, and compute loss target.
@ -2108,6 +2112,8 @@ def generate_proposal_labels(rpn_rois,
bbox_reg_weights(list|tuple): Box regression weights.
class_nums(int): Class number.
use_random(bool): Use random sampling to choose foreground and background boxes.
is_cls_agnostic(bool): bbox regression use class agnostic simply which only represent fg and bg boxes.
is_cascade_rcnn(bool): it will filter some bbox crossing the image's boundary when setting True.
Examples:
.. code-block:: python
@ -2166,7 +2172,9 @@ def generate_proposal_labels(rpn_rois,
'bg_thresh_lo': bg_thresh_lo,
'bbox_reg_weights': bbox_reg_weights,
'class_nums': class_nums,
'use_random': use_random
'use_random': use_random,
'is_cls_agnostic': is_cls_agnostic,
'is_cascade_rcnn': is_cascade_rcnn
})
rois.stop_gradient = True

@ -22,10 +22,10 @@ import paddle.fluid as fluid
from op_test import OpTest
def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
im_info, batch_size_per_im, fg_fraction,
fg_thresh, bg_thresh_hi, bg_thresh_lo,
bbox_reg_weights, class_nums):
def generate_proposal_labels_in_python(
rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums, is_cls_agnostic, is_cascade_rcnn):
rois = []
labels_int32 = []
bbox_targets = []
@ -36,13 +36,12 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_info)):
frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
frcn_blobs = _sample_rois(rpn_rois[im_i], gt_classes[im_i],
is_crowd[im_i], gt_boxes[im_i], im_info[im_i],
batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
class_nums, is_cls_agnostic, is_cascade_rcnn)
lod.append(frcn_blobs['rois'].shape[0])
rois.append(frcn_blobs['rois'])
labels_int32.append(frcn_blobs['labels_int32'])
bbox_targets.append(frcn_blobs['bbox_targets'])
@ -54,7 +53,8 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bg_thresh_lo, bbox_reg_weights, class_nums):
bg_thresh_lo, bbox_reg_weights, class_nums, is_cls_agnostic,
is_cascade_rcnn):
rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
@ -62,7 +62,8 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
im_scale = im_info[2]
inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale
if is_cascade_rcnn:
rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]
boxes = np.vstack([gt_boxes, rpn_rois])
gt_overlaps = np.zeros((boxes.shape[0], class_nums))
box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
@ -87,26 +88,37 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1)
# Foreground
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many
# if fg_inds.shape[0] > fg_rois_per_this_image:
# fg_inds = np.random.choice(
# fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
bg_thresh_lo))[0]
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0])
# Sample background if there are too many
# if bg_inds.shape[0] > bg_rois_per_this_image:
# bg_inds = np.random.choice(
# bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
# Cascade RCNN Decode Filter
if is_cascade_rcnn:
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws > 0) & (hs > 0))[0]
boxes = boxes[keep]
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
bg_thresh_lo))[0]
fg_rois_per_this_image = fg_inds.shape[0]
bg_rois_per_this_image = bg_inds.shape[0]
else:
# Foreground
fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image:
fg_inds = np.random.choice(
fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
bg_thresh_lo))[0]
bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0])
# Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image:
bg_inds = np.random.choice(
bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds]
@ -114,14 +126,12 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
sampled_boxes = boxes[keep_inds]
sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
bbox_label_targets = _compute_targets(sampled_boxes, sampled_gts,
sampled_labels, bbox_reg_weights)
bbox_targets, bbox_inside_weights = _expand_bbox_targets(bbox_label_targets,
class_nums)
bbox_targets, bbox_inside_weights = _expand_bbox_targets(
bbox_label_targets, class_nums, is_cls_agnostic)
bbox_outside_weights = np.array(
bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)
# Scale rois
sampled_rois = sampled_boxes * im_scale
@ -192,19 +202,22 @@ def _box_to_delta(ex_boxes, gt_boxes, weights):
return targets
def _expand_bbox_targets(bbox_targets_input, class_nums):
def _expand_bbox_targets(bbox_targets_input, class_nums, is_cls_agnostic):
class_labels = bbox_targets_input[:, 0]
fg_inds = np.where(class_labels > 0)[0]
bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums))
#if is_cls_agnostic:
# class_labels = [1 if ll > 0 else 0 for ll in class_labels]
# class_labels = np.array(class_labels, dtype=np.int32)
# class_nums = 2
bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums
if not is_cls_agnostic else 4 * 2))
bbox_inside_weights = np.zeros(bbox_targets.shape)
for ind in fg_inds:
class_label = int(class_labels[ind])
class_label = int(class_labels[ind]) if not is_cls_agnostic else 1
start_ind = class_label * 4
end_ind = class_label * 4 + 4
bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind, 1:]
bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0, 1.0)
return bbox_targets, bbox_inside_weights
@ -228,7 +241,9 @@ class TestGenerateProposalLabelsOp(OpTest):
'bg_thresh_lo': self.bg_thresh_lo,
'bbox_reg_weights': self.bbox_reg_weights,
'class_nums': self.class_nums,
'use_random': False
'use_random': False,
'is_cls_agnostic': self.is_cls_agnostic,
'is_cascade_rcnn': self.is_cascade_rcnn
}
self.outputs = {
'Rois': (self.rois, [self.lod]),
@ -252,12 +267,15 @@ class TestGenerateProposalLabelsOp(OpTest):
self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0
self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
self.class_nums = 81
#self.class_nums = 81
self.is_cls_agnostic = False #True
self.is_cascade_rcnn = True
self.class_nums = 2 if self.is_cls_agnostic else 81
def init_test_input(self):
np.random.seed(0)
gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = 2000 #self.batch_size_per_im - gt_nums
proposal_nums = 2000 if not self.is_cascade_rcnn else 512 #self.batch_size_per_im - gt_nums
images_shape = [[64, 64]]
self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(len(images_shape)):
@ -280,7 +298,8 @@ class TestGenerateProposalLabelsOp(OpTest):
self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums
self.bbox_reg_weights, self.class_nums,
self.is_cls_agnostic, self.is_cascade_rcnn
)
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)

Loading…
Cancel
Save