Update generate_proposal_labels_op to support CascadeRCNN. (#17200)

* Update generate_proposal_labels_op to support CascadeRCNN.
Authored by FDInSky 6 years ago; committed by qingqing01
parent 9ed2f936f1
commit 9e4b9d9798

@@ -351,7 +351,7 @@ paddle.fluid.layers.rpn_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits',
 paddle.fluid.layers.retinanet_target_assign (ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'gt_labels', 'is_crowd', 'im_info', 'num_classes', 'positive_overlap', 'negative_overlap'], varargs=None, keywords=None, defaults=(1, 0.5, 0.4)), ('document', 'fa1d1c9d5e0111684c0db705f86a2595'))
 paddle.fluid.layers.anchor_generator (ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)), ('document', '82b2aefeeb1b706bc4afec70928a259a'))
 paddle.fluid.layers.roi_perspective_transform (ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,)), ('document', 'd1ddc75629fedee46f82e631e22c79dc'))
-paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True)), ('document', '9c601df88b251f22e9311c52939948cd'))
+paddle.fluid.layers.generate_proposal_labels (ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random', 'is_cls_agnostic', 'is_cascade_rcnn'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True, False, False)), ('document', 'c0d00acf724691ff3480d4207036a722'))
 paddle.fluid.layers.generate_proposals (ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)), ('document', 'b7d707822b6af2a586bce608040235b1'))
 paddle.fluid.layers.generate_mask_labels (ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None), ('document', 'b319b10ddaf17fb4ddf03518685a17ef'))
 paddle.fluid.layers.iou_similarity (ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '72fca4a39ccf82d5c746ae62d1868a99'))

@@ -2075,9 +2075,13 @@ def generate_proposal_labels(rpn_rois,
                              bg_thresh_lo=0.0,
                              bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
                              class_nums=None,
-                             use_random=True):
+                             use_random=True,
+                             is_cls_agnostic=False,
+                             is_cascade_rcnn=False):
     """
     ** Generate Proposal Labels of Faster-RCNN **
     This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth,
     to sample foreground boxes and background boxes, and compute loss target.
@@ -2108,6 +2112,8 @@ def generate_proposal_labels(rpn_rois,
         bbox_reg_weights(list|tuple): Box regression weights.
         class_nums(int): Class number.
         use_random(bool): Use random sampling to choose foreground and background boxes.
+        is_cls_agnostic(bool): If True, bbox regression is class-agnostic, i.e. the regression targets only distinguish foreground from background boxes.
+        is_cascade_rcnn(bool): If True, bounding boxes crossing the image boundary are filtered out, as required by Cascade RCNN proposal sampling.
     Examples:
         .. code-block:: python
@@ -2166,7 +2172,9 @@ def generate_proposal_labels(rpn_rois,
             'bg_thresh_lo': bg_thresh_lo,
             'bbox_reg_weights': bbox_reg_weights,
             'class_nums': class_nums,
-            'use_random': use_random
+            'use_random': use_random,
+            'is_cls_agnostic': is_cls_agnostic,
+            'is_cascade_rcnn': is_cascade_rcnn
         })
     rois.stop_gradient = True
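
A minimal usage sketch of the updated Python API (illustrative only, not part of the diff; the data layer names and shapes below are assumptions, and both new flags default to False so existing callers keep their behavior):

    import paddle.fluid as fluid

    # Illustrative inputs; in a real network these come from generate_proposals
    # and the data reader.
    rpn_rois = fluid.layers.data(
        name='rpn_rois', shape=[4], dtype='float32', lod_level=1)
    gt_classes = fluid.layers.data(
        name='gt_classes', shape=[1], dtype='int32', lod_level=1)
    is_crowd = fluid.layers.data(
        name='is_crowd', shape=[1], dtype='int32', lod_level=1)
    gt_boxes = fluid.layers.data(
        name='gt_boxes', shape=[4], dtype='float32', lod_level=1)
    im_info = fluid.layers.data(name='im_info', shape=[3], dtype='float32')

    rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = \
        fluid.layers.generate_proposal_labels(
            rpn_rois=rpn_rois,
            gt_classes=gt_classes,
            is_crowd=is_crowd,
            gt_boxes=gt_boxes,
            im_info=im_info,
            class_nums=81,
            # New flags from this change; the defaults keep Faster-RCNN behavior.
            is_cls_agnostic=False,
            is_cascade_rcnn=True)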

@@ -22,10 +22,10 @@ import paddle.fluid as fluid
 from op_test import OpTest
 
-def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
-                                       im_info, batch_size_per_im, fg_fraction,
-                                       fg_thresh, bg_thresh_hi, bg_thresh_lo,
-                                       bbox_reg_weights, class_nums):
+def generate_proposal_labels_in_python(
+        rpn_rois, gt_classes, is_crowd, gt_boxes, im_info, batch_size_per_im,
+        fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
+        class_nums, is_cls_agnostic, is_cascade_rcnn):
     rois = []
     labels_int32 = []
     bbox_targets = []
@@ -36,13 +36,12 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
         im_info), 'batch size of rpn_rois and ground_truth is not matched'
     for im_i in range(len(im_info)):
-        frcn_blobs = _sample_rois(
-            rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
-            im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
-            bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
+        frcn_blobs = _sample_rois(rpn_rois[im_i], gt_classes[im_i],
+                                  is_crowd[im_i], gt_boxes[im_i], im_info[im_i],
+                                  batch_size_per_im, fg_fraction, fg_thresh,
+                                  bg_thresh_hi, bg_thresh_lo, bbox_reg_weights,
+                                  class_nums, is_cls_agnostic, is_cascade_rcnn)
         lod.append(frcn_blobs['rois'].shape[0])
         rois.append(frcn_blobs['rois'])
         labels_int32.append(frcn_blobs['labels_int32'])
         bbox_targets.append(frcn_blobs['bbox_targets'])
@@ -54,7 +53,8 @@ def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
 def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
                  batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
-                 bg_thresh_lo, bbox_reg_weights, class_nums):
+                 bg_thresh_lo, bbox_reg_weights, class_nums, is_cls_agnostic,
+                 is_cascade_rcnn):
     rois_per_image = int(batch_size_per_im)
     fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
@@ -62,7 +62,8 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
     im_scale = im_info[2]
     inv_im_scale = 1. / im_scale
     rpn_rois = rpn_rois * inv_im_scale
+    if is_cascade_rcnn:
+        rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]
     boxes = np.vstack([gt_boxes, rpn_rois])
     gt_overlaps = np.zeros((boxes.shape[0], class_nums))
     box_to_gt_ind_map = np.zeros((boxes.shape[0]), dtype=np.int32)
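
The new slice drops the first gt_boxes.shape[0] rows of the incoming RoIs before gt_boxes is stacked back on top, which suggests that in cascade mode the RoIs fed to a later stage already carry the ground-truth boxes at the front. A small NumPy sketch of that bookkeeping (box values invented for illustration):

    import numpy as np

    gt_boxes = np.array([[0., 0., 10., 10.],
                         [5., 5., 20., 20.]])        # 2 ground-truth boxes
    prev_rois = np.array([[1., 1., 8., 8.],
                          [3., 2., 15., 12.],
                          [6., 6., 30., 25.]])       # proposals from the previous stage

    # A previous cascade stage already prepended gt_boxes to its output RoIs,
    # so the first gt_boxes.shape[0] rows duplicate the ground truth.
    rpn_rois = np.vstack([gt_boxes, prev_rois])

    is_cascade_rcnn = True
    if is_cascade_rcnn:
        rpn_rois = rpn_rois[gt_boxes.shape[0]:, :]   # drop the duplicated gt rows

    boxes = np.vstack([gt_boxes, rpn_rois])          # gt boxes appear exactly once, first
    print(boxes.shape)                               # (5, 4)
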
@@ -87,26 +88,37 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
     max_overlaps = gt_overlaps.max(axis=1)
     max_classes = gt_overlaps.argmax(axis=1)
-    # Foreground
-    fg_inds = np.where(max_overlaps >= fg_thresh)[0]
-    fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
-    # Sample foreground if there are too many
-    # if fg_inds.shape[0] > fg_rois_per_this_image:
-    #     fg_inds = np.random.choice(
-    #         fg_inds, size=fg_rois_per_this_image, replace=False)
-    fg_inds = fg_inds[:fg_rois_per_this_image]
-
-    # Background
-    bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
-                                                        bg_thresh_lo))[0]
-    bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
-    bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
-                                        bg_inds.shape[0])
-    # Sample background if there are too many
-    # if bg_inds.shape[0] > bg_rois_per_this_image:
-    #     bg_inds = np.random.choice(
-    #         bg_inds, size=bg_rois_per_this_image, replace=False)
-    bg_inds = bg_inds[:bg_rois_per_this_image]
+    # Cascade RCNN Decode Filter
+    if is_cascade_rcnn:
+        ws = boxes[:, 2] - boxes[:, 0] + 1
+        hs = boxes[:, 3] - boxes[:, 1] + 1
+        keep = np.where((ws > 0) & (hs > 0))[0]
+        boxes = boxes[keep]
+        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
+        bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
+                                                            bg_thresh_lo))[0]
+        fg_rois_per_this_image = fg_inds.shape[0]
+        bg_rois_per_this_image = bg_inds.shape[0]
+    else:
+        # Foreground
+        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
+        fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
+        # Sample foreground if there are too many
+        if fg_inds.shape[0] > fg_rois_per_this_image:
+            fg_inds = np.random.choice(
+                fg_inds, size=fg_rois_per_this_image, replace=False)
+        fg_inds = fg_inds[:fg_rois_per_this_image]
+        # Background
+        bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
+                                                            bg_thresh_lo))[0]
+        bg_rois_per_this_image = rois_per_image - fg_rois_per_this_image
+        bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
+                                            bg_inds.shape[0])
+        # Sample background if there are too many
+        if bg_inds.shape[0] > bg_rois_per_this_image:
+            bg_inds = np.random.choice(
+                bg_inds, size=bg_rois_per_this_image, replace=False)
+        bg_inds = bg_inds[:bg_rois_per_this_image]
 
     keep_inds = np.append(fg_inds, bg_inds)
     sampled_labels = max_classes[keep_inds]
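
In the new cascade branch every candidate that clears the IoU thresholds is kept, with no foreground quota and no random sampling, while the non-cascade branch keeps the usual Faster-RCNN policy of capping foreground at fg_fraction * batch_size_per_im and filling the rest with background. A self-contained sketch of the two selection policies (the thresholds and toy overlap vector are invented):

    import numpy as np

    def select_inds(max_overlaps, fg_thresh, bg_thresh_hi, bg_thresh_lo,
                    rois_per_image, fg_fraction, is_cascade_rcnn, rng=None):
        """Return (fg_inds, bg_inds) under the two policies in this diff."""
        fg_inds = np.where(max_overlaps >= fg_thresh)[0]
        bg_inds = np.where((max_overlaps < bg_thresh_hi) &
                           (max_overlaps >= bg_thresh_lo))[0]
        if is_cascade_rcnn:
            # Cascade stage: keep all candidates; later stages rely on the IoU
            # thresholds rather than a fixed fg/bg quota.
            return fg_inds, bg_inds
        # Standard Faster-RCNN sampling: enforce the fg quota, top up with bg.
        fg_quota = min(int(np.round(fg_fraction * rois_per_image)), fg_inds.shape[0])
        bg_quota = min(rois_per_image - fg_quota, bg_inds.shape[0])
        if rng is not None:  # mirrors use_random=True
            fg_inds = rng.choice(fg_inds, size=fg_quota, replace=False)
            bg_inds = rng.choice(bg_inds, size=bg_quota, replace=False)
        return fg_inds[:fg_quota], bg_inds[:bg_quota]

    overlaps = np.array([0.9, 0.7, 0.55, 0.3, 0.1, 0.05])
    print(select_inds(overlaps, 0.5, 0.5, 0.0, 4, 0.25, is_cascade_rcnn=True))
    print(select_inds(overlaps, 0.5, 0.5, 0.0, 4, 0.25, is_cascade_rcnn=False))
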
@@ -114,14 +126,12 @@ def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
     sampled_boxes = boxes[keep_inds]
     sampled_gts = gt_boxes[box_to_gt_ind_map[keep_inds]]
     sampled_gts[fg_rois_per_this_image:, :] = gt_boxes[0]
     bbox_label_targets = _compute_targets(sampled_boxes, sampled_gts,
                                           sampled_labels, bbox_reg_weights)
-    bbox_targets, bbox_inside_weights = _expand_bbox_targets(bbox_label_targets,
-                                                             class_nums)
+    bbox_targets, bbox_inside_weights = _expand_bbox_targets(
+        bbox_label_targets, class_nums, is_cls_agnostic)
     bbox_outside_weights = np.array(
         bbox_inside_weights > 0, dtype=bbox_inside_weights.dtype)
     # Scale rois
     sampled_rois = sampled_boxes * im_scale
@@ -192,19 +202,22 @@ def _box_to_delta(ex_boxes, gt_boxes, weights):
     return targets
 
-def _expand_bbox_targets(bbox_targets_input, class_nums):
+def _expand_bbox_targets(bbox_targets_input, class_nums, is_cls_agnostic):
     class_labels = bbox_targets_input[:, 0]
     fg_inds = np.where(class_labels > 0)[0]
-    bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums))
+    #if is_cls_agnostic:
+    #    class_labels = [1 if ll > 0 else 0 for ll in class_labels]
+    #    class_labels = np.array(class_labels, dtype=np.int32)
+    #    class_nums = 2
+    bbox_targets = np.zeros((class_labels.shape[0], 4 * class_nums
+                             if not is_cls_agnostic else 4 * 2))
     bbox_inside_weights = np.zeros(bbox_targets.shape)
     for ind in fg_inds:
-        class_label = int(class_labels[ind])
+        class_label = int(class_labels[ind]) if not is_cls_agnostic else 1
         start_ind = class_label * 4
         end_ind = class_label * 4 + 4
         bbox_targets[ind, start_ind:end_ind] = bbox_targets_input[ind, 1:]
         bbox_inside_weights[ind, start_ind:end_ind] = (1.0, 1.0, 1.0, 1.0)
     return bbox_targets, bbox_inside_weights
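
With is_cls_agnostic=True the expanded target tensor has only two 4-wide slots (background plus a single shared foreground slot) instead of one slot per class, and every foreground RoI writes into slot 1. A toy illustration of the two layouts (the sample targets are invented):

    import numpy as np

    # Each row: [class_label, dx, dy, dw, dh]; two foreground RoIs of classes
    # 3 and 7 plus one background RoI. Values are made up for illustration.
    bbox_targets_input = np.array([[3, 0.1, 0.2, 0.3, 0.4],
                                   [7, -0.1, 0.0, 0.2, 0.1],
                                   [0, 0.0, 0.0, 0.0, 0.0]])
    class_nums = 8

    def expand(bbox_targets_input, class_nums, is_cls_agnostic):
        labels = bbox_targets_input[:, 0]
        width = 4 * (2 if is_cls_agnostic else class_nums)
        out = np.zeros((labels.shape[0], width))
        for ind in np.where(labels > 0)[0]:
            cls = 1 if is_cls_agnostic else int(labels[ind])
            out[ind, cls * 4:cls * 4 + 4] = bbox_targets_input[ind, 1:]
        return out

    print(expand(bbox_targets_input, class_nums, is_cls_agnostic=False).shape)  # (3, 32)
    print(expand(bbox_targets_input, class_nums, is_cls_agnostic=True).shape)   # (3, 8)
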
@@ -228,7 +241,9 @@ class TestGenerateProposalLabelsOp(OpTest):
             'bg_thresh_lo': self.bg_thresh_lo,
             'bbox_reg_weights': self.bbox_reg_weights,
             'class_nums': self.class_nums,
-            'use_random': False
+            'use_random': False,
+            'is_cls_agnostic': self.is_cls_agnostic,
+            'is_cascade_rcnn': self.is_cascade_rcnn
         }
         self.outputs = {
             'Rois': (self.rois, [self.lod]),
@@ -252,12 +267,15 @@ class TestGenerateProposalLabelsOp(OpTest):
         self.bg_thresh_hi = 0.5
         self.bg_thresh_lo = 0.0
         self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
-        self.class_nums = 81
+        #self.class_nums = 81
+        self.is_cls_agnostic = False  #True
+        self.is_cascade_rcnn = True
+        self.class_nums = 2 if self.is_cls_agnostic else 81
 
     def init_test_input(self):
         np.random.seed(0)
         gt_nums = 6  # Keep same with batch_size_per_im for unittest
-        proposal_nums = 2000  #self.batch_size_per_im - gt_nums
+        proposal_nums = 2000 if not self.is_cascade_rcnn else 512  #self.batch_size_per_im - gt_nums
         images_shape = [[64, 64]]
         self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
         for i in range(len(images_shape)):
@@ -280,7 +298,8 @@ class TestGenerateProposalLabelsOp(OpTest):
             self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
             self.batch_size_per_im, self.fg_fraction,
             self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
-            self.bbox_reg_weights, self.class_nums
+            self.bbox_reg_weights, self.class_nums,
+            self.is_cls_agnostic, self.is_cascade_rcnn
         )
         self.rois = np.vstack(self.rois)
         self.labels_int32 = np.hstack(self.labels_int32)
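
The base test now hard-codes is_cascade_rcnn = True and is_cls_agnostic = False. One way to exercise all combinations without editing the base class would be subclasses along these lines (a hypothetical sketch, assuming the flag assignments above live in an init_test_params hook that setUp calls before building the attrs; not part of this commit):

    class TestCascadeRCNNProposalLabels(TestGenerateProposalLabelsOp):
        def init_test_params(self):
            super(TestCascadeRCNNProposalLabels, self).init_test_params()
            self.is_cascade_rcnn = True
            self.is_cls_agnostic = False


    class TestClsAgnosticProposalLabels(TestGenerateProposalLabelsOp):
        def init_test_params(self):
            super(TestClsAgnosticProposalLabels, self).init_test_params()
            self.is_cascade_rcnn = False
            self.is_cls_agnostic = True
            self.class_nums = 2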
