Enhence generate_proposal_labels_op and fix some bug. (#13239)

* Enhence generate_proposal_labels_op
* Fix bug in generate_proposals_op
* Refine rpn_target_assign_op.
* by Bu Xingyuan, Wang Guanzhong and Dang Qingqing
fix-develop-build.sh
Xingyuan Bu 7 years ago committed by qingqing01
parent 23ec966cd3
commit 9e2e893f59

@ -305,9 +305,9 @@ paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'neg
paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0)) paddle.fluid.layers.detection_output ArgSpec(args=['loc', 'scores', 'prior_box', 'prior_box_var', 'background_label', 'nms_threshold', 'nms_top_k', 'keep_top_k', 'score_threshold', 'nms_eta'], varargs=None, keywords=None, defaults=(0, 0.3, 400, 200, 0.01, 1.0))
paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None)) paddle.fluid.layers.ssd_loss ArgSpec(args=['location', 'confidence', 'gt_box', 'gt_label', 'prior_box', 'prior_box_var', 'background_label', 'overlap_threshold', 'neg_pos_ratio', 'neg_overlap', 'loc_loss_weight', 'conf_loss_weight', 'match_type', 'mining_type', 'normalize', 'sample_size'], varargs=None, keywords=None, defaults=(None, 0, 0.5, 3.0, 0.5, 1.0, 1.0, 'per_prediction', 'max_negative', True, None))
paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral')) paddle.fluid.layers.detection_map ArgSpec(args=['detect_res', 'label', 'class_num', 'background_label', 'overlap_threshold', 'evaluate_difficult', 'has_state', 'input_states', 'out_states', 'ap_version'], varargs=None, keywords=None, defaults=(0, 0.3, True, None, None, None, 'integral'))
paddle.fluid.layers.rpn_target_assign ArgSpec(args=['loc', 'scores', 'anchor_box', 'anchor_var', 'gt_box', 'rpn_batch_size_per_im', 'fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap'], varargs=None, keywords=None, defaults=(256, 0.25, 0.7, 0.3)) paddle.fluid.layers.rpn_target_assign ArgSpec(args=['bbox_pred', 'cls_logits', 'anchor_box', 'anchor_var', 'gt_boxes', 'is_crowd', 'im_info', 'rpn_batch_size_per_im', 'rpn_straddle_thresh', 'rpn_fg_fraction', 'rpn_positive_overlap', 'rpn_negative_overlap', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.0, 0.5, 0.7, 0.3, True))
paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None)) paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'aspect_ratios', 'variance', 'stride', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, [0.1, 0.1, 0.2, 0.2], None, 0.5, None))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'gt_boxes', 'im_scales', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None)) paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None)) paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None) paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)

@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
@ -21,7 +22,7 @@ namespace operators {
*/ */
template <typename T> template <typename T>
inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes, inline void BoxToDelta(const int box_num, const framework::Tensor& ex_boxes,
const framework::Tensor& gt_boxes, const T* weights, const framework::Tensor& gt_boxes, const float* weights,
const bool normalized, framework::Tensor* box_delta) { const bool normalized, framework::Tensor* box_delta) {
auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes); auto ex_boxes_et = framework::EigenTensor<T, 2>::From(ex_boxes);
auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes); auto gt_boxes_et = framework::EigenTensor<T, 2>::From(gt_boxes);
@ -62,5 +63,35 @@ void Gather(const T* in, const int in_stride, const int* index, const int num,
} }
} }
template <typename T>
void BboxOverlaps(const framework::Tensor& r_boxes,
const framework::Tensor& c_boxes,
framework::Tensor* overlaps) {
auto r_boxes_et = framework::EigenTensor<T, 2>::From(r_boxes);
auto c_boxes_et = framework::EigenTensor<T, 2>::From(c_boxes);
auto overlaps_et = framework::EigenTensor<T, 2>::From(*overlaps);
int r_num = r_boxes.dims()[0];
int c_num = c_boxes.dims()[0];
auto zero = static_cast<T>(0.0);
T r_box_area, c_box_area, x_min, y_min, x_max, y_max, inter_w, inter_h,
inter_area;
for (int i = 0; i < r_num; ++i) {
r_box_area = (r_boxes_et(i, 2) - r_boxes_et(i, 0) + 1) *
(r_boxes_et(i, 3) - r_boxes_et(i, 1) + 1);
for (int j = 0; j < c_num; ++j) {
c_box_area = (c_boxes_et(j, 2) - c_boxes_et(j, 0) + 1) *
(c_boxes_et(j, 3) - c_boxes_et(j, 1) + 1);
x_min = std::max(r_boxes_et(i, 0), c_boxes_et(j, 0));
y_min = std::max(r_boxes_et(i, 1), c_boxes_et(j, 1));
x_max = std::min(r_boxes_et(i, 2), c_boxes_et(j, 2));
y_max = std::min(r_boxes_et(i, 3), c_boxes_et(j, 3));
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
}
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -89,12 +89,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
} }
for (int64_t i = 0; i < row; ++i) { for (int64_t i = 0; i < row; ++i) {
T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len]; T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1]; T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
T anchor_center_x = (anchor_data[i * len + 2] + anchor_data[i * len]) / 2; T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
T anchor_center_y = T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
(anchor_data[i * len + 3] + anchor_data[i * len + 1]) / 2;
T bbox_center_x = 0, bbox_center_y = 0; T bbox_center_x = 0, bbox_center_y = 0;
T bbox_width = 0, bbox_height = 0; T bbox_width = 0, bbox_height = 0;
@ -106,25 +105,31 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
bbox_center_y = variances_data[i * len + 1] * bbox_center_y = variances_data[i * len + 1] *
bbox_deltas_data[i * len + 1] * anchor_height + bbox_deltas_data[i * len + 1] * anchor_height +
anchor_center_y; anchor_center_y;
bbox_width = std::exp(variances_data[i * len + 2] * bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
bbox_deltas_data[i * len + 2]) * bbox_deltas_data[i * len + 2],
std::log(1000.0 / 16.0))) *
anchor_width; anchor_width;
bbox_height = std::exp(variances_data[i * len + 3] * bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
bbox_deltas_data[i * len + 3]) * bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height; anchor_height;
} else { } else {
bbox_center_x = bbox_center_x =
bbox_deltas_data[i * len] * anchor_width + anchor_center_x; bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
bbox_center_y = bbox_center_y =
bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y; bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
bbox_width = std::exp(bbox_deltas_data[i * len + 2]) * anchor_width; bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
bbox_height = std::exp(bbox_deltas_data[i * len + 3]) * anchor_height; std::log(1000.0 / 16.0))) *
anchor_width;
bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
std::log(1000.0 / 16.0))) *
anchor_height;
} }
proposals_data[i * len] = bbox_center_x - bbox_width / 2; proposals_data[i * len] = bbox_center_x - bbox_width / 2;
proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2; proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2; proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2; proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
} }
// return proposals; // return proposals;
} }
@ -156,18 +161,23 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
float min_size, const Tensor &im_info, Tensor *keep) { float min_size, const Tensor &im_info, Tensor *keep) {
const T *im_info_data = im_info.data<T>(); const T *im_info_data = im_info.data<T>();
T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace()); T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
min_size *= im_info_data[2]; T im_scale = im_info_data[2];
keep->Resize({boxes->dims()[0], 1}); keep->Resize({boxes->dims()[0], 1});
min_size = std::max(min_size, 1.0f);
int *keep_data = keep->mutable_data<int>(ctx.GetPlace()); int *keep_data = keep->mutable_data<int>(ctx.GetPlace());
int keep_len = 0; int keep_len = 0;
for (int i = 0; i < boxes->dims()[0]; ++i) { for (int i = 0; i < boxes->dims()[0]; ++i) {
T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1; T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1; T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
T ws_origin_scale =
(boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
T hs_origin_scale =
(boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
T x_ctr = boxes_data[4 * i] + ws / 2; T x_ctr = boxes_data[4 * i] + ws / 2;
T y_ctr = boxes_data[4 * i + 1] + hs / 2; T y_ctr = boxes_data[4 * i + 1] + hs / 2;
if (ws >= min_size && hs >= min_size && x_ctr <= im_info_data[1] && if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
y_ctr <= im_info_data[0]) { x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
keep_data[keep_len++] = i; keep_data[keep_len++] = i;
} }
} }
@ -218,8 +228,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
const T inter_ymin = std::max(box1[1], box2[1]); const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]); const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]); const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin; const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
const T inter_h = inter_ymax - inter_ymin; const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
const T inter_area = inter_w * inter_h; const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized); const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized); const T bbox2_area = BBoxArea<T>(box2, normalized);

File diff suppressed because it is too large Load Diff

@ -55,15 +55,19 @@ for _OP in set(__auto__):
globals()[_OP] = generate_layer_fn(_OP) globals()[_OP] = generate_layer_fn(_OP)
def rpn_target_assign(loc, def rpn_target_assign(bbox_pred,
scores, cls_logits,
anchor_box, anchor_box,
anchor_var, anchor_var,
gt_box, gt_boxes,
is_crowd,
im_info,
rpn_batch_size_per_im=256, rpn_batch_size_per_im=256,
fg_fraction=0.25, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3): rpn_negative_overlap=0.3,
use_random=True):
""" """
** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. ** ** Target Assign Layer for region proposal network (RPN) in Faster-RCNN detection. **
@ -83,14 +87,13 @@ def rpn_target_assign(loc,
the positive anchors. the positive anchors.
Args: Args:
loc(Variable): A 3-D Tensor with shape [N, M, 4] represents the bbox_pred(Variable): A 3-D Tensor with shape [N, M, 4] represents the
predicted locations of M bounding bboxes. N is the batch size, predicted locations of M bounding bboxes. N is the batch size,
and each bounding box has four coordinate values and the layout and each bounding box has four coordinate values and the layout
is [xmin, ymin, xmax, ymax]. is [xmin, ymin, xmax, ymax].
scores(Variable): A 3-D Tensor with shape [N, M, C] represents the cls_logits(Variable): A 3-D Tensor with shape [N, M, 1] represents the
predicted confidence predictions. N is the batch size, C is the predicted confidence predictions. N is the batch size, 1 is the
class number, M is number of bounding boxes. For each category frontground and background sigmoid, M is number of bounding boxes.
there are total M scores which corresponding M bounding boxes.
anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes, anchor_box(Variable): A 2-D Tensor with shape [M, 4] holds M boxes,
each box is represented as [xmin, ymin, xmax, ymax], each box is represented as [xmin, ymin, xmax, ymax],
[xmin, ymin] is the left top coordinate of the anchor box, [xmin, ymin] is the left top coordinate of the anchor box,
@ -99,11 +102,16 @@ def rpn_target_assign(loc,
coordinate of the anchor box. coordinate of the anchor box.
anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded anchor_var(Variable): A 2-D Tensor with shape [M,4] holds expanded
variances of anchors. variances of anchors.
gt_box (Variable): The ground-truth boudding boxes (bboxes) are a 2D gt_boxes (Variable): The ground-truth boudding boxes (bboxes) are a 2D
LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth LoDTensor with shape [Ng, 4], Ng is the total number of ground-truth
bboxes of mini-batch input. bboxes of mini-batch input.
is_crowd (Variable): A 1-D LoDTensor which indicates groud-truth is crowd.
im_info (Variable): A 2-D LoDTensor with shape [N, 3]. N is the batch size,
3 is the height, width and scale.
rpn_batch_size_per_im(int): Total number of RPN examples per image. rpn_batch_size_per_im(int): Total number of RPN examples per image.
fg_fraction(float): Target fraction of RoI minibatch that is labeled rpn_straddle_thresh(float): Remove RPN anchors that go outside the image
by straddle_thresh pixels.
rpn_fg_fraction(float): Target fraction of RoI minibatch that is labeled
foreground (i.e. class > 0), 0-th class is background. foreground (i.e. class > 0), 0-th class is background.
rpn_positive_overlap(float): Minimum overlap required between an anchor rpn_positive_overlap(float): Minimum overlap required between an anchor
and ground-truth box for the (anchor, gt box) pair to be a positive and ground-truth box for the (anchor, gt box) pair to be a positive
@ -129,45 +137,48 @@ def rpn_target_assign(loc,
Examples: Examples:
.. code-block:: python .. code-block:: python
loc = layers.data(name='location', shape=[2, 80], bbox_pred = layers.data(name='bbox_pred', shape=[100, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
scores = layers.data(name='scores', shape=[2, 40], cls_logits = layers.data(name='cls_logits', shape=[100, 1],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
anchor_box = layers.data(name='anchor_box', shape=[20, 4], anchor_box = layers.data(name='anchor_box', shape=[20, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
gt_box = layers.data(name='gt_box', shape=[10, 4], gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target = loc_pred, score_pred, loc_target, score_target =
fluid.layers.detection_output(loc=location, fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
scores=scores, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
gt_box=gt_box) gt_boxes=gt_boxes)
""" """
helper = LayerHelper('rpn_target_assign', **locals()) helper = LayerHelper('rpn_target_assign', **locals())
# Compute overlaps between the prior boxes and the gt boxes overlaps
iou = iou_similarity(x=gt_box, y=anchor_box)
# Assign target label to anchors # Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32') loc_index = helper.create_tmp_variable(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32') score_index = helper.create_tmp_variable(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int64') target_label = helper.create_tmp_variable(dtype='int32')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
helper.append_op( helper.append_op(
type="rpn_target_assign", type="rpn_target_assign",
inputs={'Anchor': anchor_box, inputs={
'GtBox': gt_box, 'Anchor': anchor_box,
'DistMat': iou}, 'GtBoxes': gt_boxes,
'IsCrowd': is_crowd,
'ImInfo': im_info
},
outputs={ outputs={
'LocationIndex': loc_index, 'LocationIndex': loc_index,
'ScoreIndex': score_index, 'ScoreIndex': score_index,
'TargetLabel': target_label, 'TargetLabel': target_label,
'TargetBBox': target_bbox, 'TargetBBox': target_bbox
}, },
attrs={ attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
'rpn_straddle_thresh': rpn_straddle_thresh,
'rpn_positive_overlap': rpn_positive_overlap, 'rpn_positive_overlap': rpn_positive_overlap,
'rpn_negative_overlap': rpn_negative_overlap, 'rpn_negative_overlap': rpn_negative_overlap,
'fg_fraction': fg_fraction 'rpn_fg_fraction': rpn_fg_fraction,
'use_random': use_random
}) })
loc_index.stop_gradient = True loc_index.stop_gradient = True
@ -175,12 +186,12 @@ def rpn_target_assign(loc,
target_label.stop_gradient = True target_label.stop_gradient = True
target_bbox.stop_gradient = True target_bbox.stop_gradient = True
scores = nn.reshape(x=scores, shape=(-1, 1)) cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
loc = nn.reshape(x=loc, shape=(-1, 4)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_scores = nn.gather(scores, score_index) predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_location = nn.gather(loc, loc_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_scores, predicted_location, target_label, target_bbox return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
def detection_output(loc, def detection_output(loc,
@ -1258,15 +1269,17 @@ def anchor_generator(input,
def generate_proposal_labels(rpn_rois, def generate_proposal_labels(rpn_rois,
gt_classes, gt_classes,
is_crowd,
gt_boxes, gt_boxes,
im_scales, im_info,
batch_size_per_im=256, batch_size_per_im=256,
fg_fraction=0.25, fg_fraction=0.25,
fg_thresh=0.25, fg_thresh=0.25,
bg_thresh_hi=0.5, bg_thresh_hi=0.5,
bg_thresh_lo=0.0, bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=None): class_nums=None,
use_random=True):
""" """
** Generate proposal labels Faster-RCNN ** ** Generate proposal labels Faster-RCNN **
TODO(buxingyuan): Add Document TODO(buxingyuan): Add Document
@ -1285,8 +1298,9 @@ def generate_proposal_labels(rpn_rois,
inputs={ inputs={
'RpnRois': rpn_rois, 'RpnRois': rpn_rois,
'GtClasses': gt_classes, 'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtBoxes': gt_boxes, 'GtBoxes': gt_boxes,
'ImScales': im_scales 'ImInfo': im_info
}, },
outputs={ outputs={
'Rois': rois, 'Rois': rois,
@ -1302,7 +1316,8 @@ def generate_proposal_labels(rpn_rois,
'bg_thresh_hi': bg_thresh_hi, 'bg_thresh_hi': bg_thresh_hi,
'bg_thresh_lo': bg_thresh_lo, 'bg_thresh_lo': bg_thresh_lo,
'bbox_reg_weights': bbox_reg_weights, 'bbox_reg_weights': bbox_reg_weights,
'class_nums': class_nums 'class_nums': class_nums,
'use_random': use_random
}) })
rois.stop_gradient = True rois.stop_gradient = True

@ -148,51 +148,60 @@ class TestAnchorGenerator(unittest.TestCase):
class TestGenerateProposalLabels(unittest.TestCase): class TestGenerateProposalLabels(unittest.TestCase):
def test_generate_proposal_labels(self): def test_generate_proposal_labels(self):
rpn_rois = layers.data( program = Program()
name='rpn_rois', with program_guard(program):
shape=[4, 4], rpn_rois = layers.data(
dtype='float32', name='rpn_rois',
lod_level=1, shape=[4, 4],
append_batch_size=False) dtype='float32',
gt_classes = layers.data( lod_level=1,
name='gt_classes', append_batch_size=False)
shape=[6], gt_classes = layers.data(
dtype='int32', name='gt_classes',
lod_level=1, shape=[6],
append_batch_size=False) dtype='int32',
gt_boxes = layers.data( lod_level=1,
name='gt_boxes', append_batch_size=False)
shape=[6, 4], is_crowd = layers.data(
dtype='float32', name='is_crowd',
lod_level=1, shape=[6],
append_batch_size=False) dtype='int32',
im_scales = layers.data( lod_level=1,
name='im_scales', append_batch_size=False)
shape=[1], gt_boxes = layers.data(
dtype='float32', name='gt_boxes',
lod_level=1, shape=[6, 4],
append_batch_size=False) dtype='float32',
class_nums = 5 lod_level=1,
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels( append_batch_size=False)
rpn_rois=rpn_rois, im_info = layers.data(
gt_classes=gt_classes, name='im_info',
gt_boxes=gt_boxes, shape=[1, 3],
im_scales=im_scales, dtype='float32',
batch_size_per_im=2, lod_level=1,
fg_fraction=0.5, append_batch_size=False)
fg_thresh=0.5, class_nums = 5
bg_thresh_hi=0.5, rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
bg_thresh_lo=0.0, rpn_rois=rpn_rois,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2], gt_classes=gt_classes,
class_nums=class_nums) is_crowd=is_crowd,
assert rois.shape[1] == 4 gt_boxes=gt_boxes,
assert rois.shape[0] == labels_int32.shape[0] im_info=im_info,
assert rois.shape[0] == bbox_targets.shape[0] batch_size_per_im=2,
assert rois.shape[0] == bbox_inside_weights.shape[0] fg_fraction=0.5,
assert rois.shape[0] == bbox_outside_weights.shape[0] fg_thresh=0.5,
assert bbox_targets.shape[1] == 4 * class_nums bg_thresh_hi=0.5,
assert bbox_inside_weights.shape[1] == 4 * class_nums bg_thresh_lo=0.0,
assert bbox_outside_weights.shape[1] == 4 * class_nums bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
assert rois.shape[0] == bbox_inside_weights.shape[0]
assert rois.shape[0] == bbox_outside_weights.shape[0]
assert bbox_targets.shape[1] == 4 * class_nums
assert bbox_inside_weights.shape[1] == 4 * class_nums
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestMultiBoxHead(unittest.TestCase): class TestMultiBoxHead(unittest.TestCase):
@ -254,18 +263,18 @@ class TestRpnTargetAssign(unittest.TestCase):
def test_rpn_target_assign(self): def test_rpn_target_assign(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
loc_shape = [10, 50, 4] bbox_pred_shape = [10, 50, 4]
score_shape = [10, 50, 2] cls_logits_shape = [10, 50, 2]
anchor_shape = [50, 4] anchor_shape = [50, 4]
loc = layers.data( bbox_pred = layers.data(
name='loc', name='bbox_pred',
shape=loc_shape, shape=bbox_pred_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
scores = layers.data( cls_logits = layers.data(
name='scores', name='cls_logits',
shape=score_shape, shape=cls_logits_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
anchor_box = layers.data( anchor_box = layers.data(
@ -278,17 +287,31 @@ class TestRpnTargetAssign(unittest.TestCase):
shape=anchor_shape, shape=anchor_shape,
append_batch_size=False, append_batch_size=False,
dtype='float32') dtype='float32')
gt_box = layers.data( gt_boxes = layers.data(
name='gt_box', shape=[4], lod_level=1, dtype='float32') name='gt_boxes', shape=[4], lod_level=1, dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[10],
dtype='int32',
lod_level=1,
append_batch_size=False)
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
loc=loc, bbox_pred=bbox_pred,
scores=scores, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
anchor_var=anchor_var, anchor_var=anchor_var,
gt_box=gt_box, gt_boxes=gt_boxes,
is_crowd=is_crowd,
im_info=im_info,
rpn_batch_size_per_im=256, rpn_batch_size_per_im=256,
fg_fraction=0.25, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3) rpn_negative_overlap=0.3)

@ -20,10 +20,10 @@ import paddle.fluid as fluid
from op_test import OpTest from op_test import OpTest
def generate_proposal_labels_in_python( def generate_proposal_labels_in_python(rpn_rois, gt_classes, is_crowd, gt_boxes,
rpn_rois, gt_classes, gt_boxes, im_scales, batch_size_per_im, im_info, batch_size_per_im, fg_fraction,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, fg_thresh, bg_thresh_hi, bg_thresh_lo,
class_nums): bbox_reg_weights, class_nums):
rois = [] rois = []
labels_int32 = [] labels_int32 = []
bbox_targets = [] bbox_targets = []
@ -31,13 +31,13 @@ def generate_proposal_labels_in_python(
bbox_outside_weights = [] bbox_outside_weights = []
lod = [] lod = []
assert len(rpn_rois) == len( assert len(rpn_rois) == len(
im_scales), 'batch size of rpn_rois and ground_truth is not matched' im_info), 'batch size of rpn_rois and ground_truth is not matched'
for im_i in range(len(im_scales)): for im_i in range(len(im_info)):
frcn_blobs = _sample_rois( frcn_blobs = _sample_rois(
rpn_rois[im_i], gt_classes[im_i], gt_boxes[im_i], im_scales[im_i], rpn_rois[im_i], gt_classes[im_i], is_crowd[im_i], gt_boxes[im_i],
batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi, im_info[im_i], batch_size_per_im, fg_fraction, fg_thresh,
bg_thresh_lo, bbox_reg_weights, class_nums) bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums)
lod.append(frcn_blobs['rois'].shape[0]) lod.append(frcn_blobs['rois'].shape[0])
@ -50,13 +50,14 @@ def generate_proposal_labels_in_python(
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights, lod
def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im, def _sample_rois(rpn_rois, gt_classes, is_crowd, gt_boxes, im_info,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, batch_size_per_im, fg_fraction, fg_thresh, bg_thresh_hi,
bbox_reg_weights, class_nums): bg_thresh_lo, bbox_reg_weights, class_nums):
rois_per_image = int(batch_size_per_im) rois_per_image = int(batch_size_per_im)
fg_rois_per_im = int(np.round(fg_fraction * rois_per_image)) fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
# Roidb # Roidb
im_scale = im_info[2]
inv_im_scale = 1. / im_scale inv_im_scale = 1. / im_scale
rpn_rois = rpn_rois * inv_im_scale rpn_rois = rpn_rois * inv_im_scale
@ -78,6 +79,9 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[ box_to_gt_ind_map[overlapped_boxes_ind] = overlaps_argmax[
overlapped_boxes_ind] overlapped_boxes_ind]
crowd_ind = np.where(is_crowd)[0]
gt_overlaps[crowd_ind] = -1
max_overlaps = gt_overlaps.max(axis=1) max_overlaps = gt_overlaps.max(axis=1)
max_classes = gt_overlaps.argmax(axis=1) max_classes = gt_overlaps.argmax(axis=1)
@ -85,9 +89,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
fg_inds = np.where(max_overlaps >= fg_thresh)[0] fg_inds = np.where(max_overlaps >= fg_thresh)[0]
fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0]) fg_rois_per_this_image = np.minimum(fg_rois_per_im, fg_inds.shape[0])
# Sample foreground if there are too many # Sample foreground if there are too many
if fg_inds.shape[0] > fg_rois_per_this_image: # if fg_inds.shape[0] > fg_rois_per_this_image:
fg_inds = np.random.choice( # fg_inds = np.random.choice(
fg_inds, size=fg_rois_per_this_image, replace=False) # fg_inds, size=fg_rois_per_this_image, replace=False)
fg_inds = fg_inds[:fg_rois_per_this_image]
# Background # Background
bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >= bg_inds = np.where((max_overlaps < bg_thresh_hi) & (max_overlaps >=
@ -96,9 +101,10 @@ def _sample_rois(rpn_rois, gt_classes, gt_boxes, im_scale, batch_size_per_im,
bg_rois_per_this_image = np.minimum(bg_rois_per_this_image, bg_rois_per_this_image = np.minimum(bg_rois_per_this_image,
bg_inds.shape[0]) bg_inds.shape[0])
# Sample background if there are too many # Sample background if there are too many
if bg_inds.shape[0] > bg_rois_per_this_image: # if bg_inds.shape[0] > bg_rois_per_this_image:
bg_inds = np.random.choice( # bg_inds = np.random.choice(
bg_inds, size=bg_rois_per_this_image, replace=False) # bg_inds, size=bg_rois_per_this_image, replace=False)
bg_inds = bg_inds[:bg_rois_per_this_image]
keep_inds = np.append(fg_inds, bg_inds) keep_inds = np.append(fg_inds, bg_inds)
sampled_labels = max_classes[keep_inds] sampled_labels = max_classes[keep_inds]
@ -208,8 +214,9 @@ class TestGenerateProposalLabelsOp(OpTest):
self.inputs = { self.inputs = {
'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod), 'RpnRois': (self.rpn_rois[0], self.rpn_rois_lod),
'GtClasses': (self.gt_classes[0], self.gts_lod), 'GtClasses': (self.gt_classes[0], self.gts_lod),
'IsCrowd': (self.is_crowd[0], self.gts_lod),
'GtBoxes': (self.gt_boxes[0], self.gts_lod), 'GtBoxes': (self.gt_boxes[0], self.gts_lod),
'ImScales': self.im_scales[0] 'ImInfo': self.im_info
} }
self.attrs = { self.attrs = {
'batch_size_per_im': self.batch_size_per_im, 'batch_size_per_im': self.batch_size_per_im,
@ -218,14 +225,15 @@ class TestGenerateProposalLabelsOp(OpTest):
'bg_thresh_hi': self.bg_thresh_hi, 'bg_thresh_hi': self.bg_thresh_hi,
'bg_thresh_lo': self.bg_thresh_lo, 'bg_thresh_lo': self.bg_thresh_lo,
'bbox_reg_weights': self.bbox_reg_weights, 'bbox_reg_weights': self.bbox_reg_weights,
'class_nums': self.class_nums 'class_nums': self.class_nums,
'use_random': False
} }
self.outputs = { self.outputs = {
'Rois': (self.rois[0], [self.lod]), 'Rois': (self.rois, [self.lod]),
'LabelsInt32': (self.labels_int32[0], [self.lod]), 'LabelsInt32': (self.labels_int32, [self.lod]),
'BboxTargets': (self.bbox_targets[0], [self.lod]), 'BboxTargets': (self.bbox_targets, [self.lod]),
'BboxInsideWeights': (self.bbox_inside_weights[0], [self.lod]), 'BboxInsideWeights': (self.bbox_inside_weights, [self.lod]),
'BboxOutsideWeights': (self.bbox_outside_weights[0], [self.lod]), 'BboxOutsideWeights': (self.bbox_outside_weights, [self.lod]),
} }
def test_check_output(self): def test_check_output(self):
@ -236,8 +244,8 @@ class TestGenerateProposalLabelsOp(OpTest):
self.set_data() self.set_data()
def init_test_params(self): def init_test_params(self):
self.batch_size_per_im = 10 self.batch_size_per_im = 512
self.fg_fraction = 1.0 self.fg_fraction = 0.25
self.fg_thresh = 0.5 self.fg_thresh = 0.5
self.bg_thresh_hi = 0.5 self.bg_thresh_hi = 0.5
self.bg_thresh_lo = 0.0 self.bg_thresh_lo = 0.0
@ -246,14 +254,14 @@ class TestGenerateProposalLabelsOp(OpTest):
def init_test_input(self): def init_test_input(self):
np.random.seed(0) np.random.seed(0)
image_nums = 1
gt_nums = 6 # Keep same with batch_size_per_im for unittest gt_nums = 6 # Keep same with batch_size_per_im for unittest
proposal_nums = self.batch_size_per_im - gt_nums proposal_nums = 2000 #self.batch_size_per_im - gt_nums
images_shape = [] images_shape = [[64, 64]]
self.im_scales = [] self.im_info = np.ones((len(images_shape), 3)).astype(np.float32)
for i in range(image_nums): for i in range(len(images_shape)):
images_shape.append(np.random.randint(200, size=2)) self.im_info[i, 0] = images_shape[i][0]
self.im_scales.append(np.ones((1)).astype(np.float32)) self.im_info[i, 1] = images_shape[i][1]
self.im_info[i, 2] = 0.8 #scale
self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape, self.rpn_rois, self.rpn_rois_lod = _generate_proposals(images_shape,
proposal_nums) proposal_nums)
@ -261,16 +269,23 @@ class TestGenerateProposalLabelsOp(OpTest):
images_shape, self.class_nums, gt_nums) images_shape, self.class_nums, gt_nums)
self.gt_classes = [gt['gt_classes'] for gt in ground_truth] self.gt_classes = [gt['gt_classes'] for gt in ground_truth]
self.gt_boxes = [gt['boxes'] for gt in ground_truth] self.gt_boxes = [gt['boxes'] for gt in ground_truth]
self.is_crowd = [gt['is_crowd'] for gt in ground_truth]
def init_test_output(self): def init_test_output(self):
self.rois, self.labels_int32, self.bbox_targets, \ self.rois, self.labels_int32, self.bbox_targets, \
self.bbox_inside_weights, self.bbox_outside_weights, \ self.bbox_inside_weights, self.bbox_outside_weights, \
self.lod = generate_proposal_labels_in_python( self.lod = generate_proposal_labels_in_python(
self.rpn_rois, self.gt_classes, self.gt_boxes, self.im_scales, self.rpn_rois, self.gt_classes, self.is_crowd, self.gt_boxes, self.im_info,
self.batch_size_per_im, self.fg_fraction, self.batch_size_per_im, self.fg_fraction,
self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo, self.fg_thresh, self.bg_thresh_hi, self.bg_thresh_lo,
self.bbox_reg_weights, self.class_nums self.bbox_reg_weights, self.class_nums
) )
self.rois = np.vstack(self.rois)
self.labels_int32 = np.hstack(self.labels_int32)
self.labels_int32 = self.labels_int32[:, np.newaxis]
self.bbox_targets = np.vstack(self.bbox_targets)
self.bbox_inside_weights = np.vstack(self.bbox_inside_weights)
self.bbox_outside_weights = np.vstack(self.bbox_outside_weights)
def _generate_proposals(images_shape, proposal_nums): def _generate_proposals(images_shape, proposal_nums):
@ -280,7 +295,7 @@ def _generate_proposals(images_shape, proposal_nums):
for i, image_shape in enumerate(images_shape): for i, image_shape in enumerate(images_shape):
proposals = _generate_boxes(image_shape, proposal_nums) proposals = _generate_boxes(image_shape, proposal_nums)
rpn_rois.append(proposals) rpn_rois.append(proposals)
num_proposals += len(proposals) num_proposals = len(proposals)
rpn_rois_lod.append(num_proposals) rpn_rois_lod.append(num_proposals)
return rpn_rois, [rpn_rois_lod] return rpn_rois, [rpn_rois_lod]
@ -294,7 +309,11 @@ def _generate_groundtruth(images_shape, class_nums, gt_nums):
gt_classes = np.random.randint( gt_classes = np.random.randint(
low=1, high=class_nums, size=gt_nums).astype(np.int32) low=1, high=class_nums, size=gt_nums).astype(np.int32)
gt_boxes = _generate_boxes(image_shape, gt_nums) gt_boxes = _generate_boxes(image_shape, gt_nums)
ground_truth.append(dict(gt_classes=gt_classes, boxes=gt_boxes)) is_crowd = np.zeros((gt_nums), dtype=np.int32)
is_crowd[0] = 1
ground_truth.append(
dict(
gt_classes=gt_classes, boxes=gt_boxes, is_crowd=is_crowd))
num_gts += len(gt_classes) num_gts += len(gt_classes)
gts_lod.append(num_gts) gts_lod.append(num_gts)
return ground_truth, [gts_lod] return ground_truth, [gts_lod]

@ -114,10 +114,10 @@ def box_coder(all_anchors, bbox_deltas, variances):
#anchor_loc: width, height, center_x, center_y #anchor_loc: width, height, center_x, center_y
anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32) anchor_loc = np.zeros_like(bbox_deltas, dtype=np.float32)
anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] anchor_loc[:, 0] = all_anchors[:, 2] - all_anchors[:, 0] + 1
anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] anchor_loc[:, 1] = all_anchors[:, 3] - all_anchors[:, 1] + 1
anchor_loc[:, 2] = (all_anchors[:, 2] + all_anchors[:, 0]) / 2 anchor_loc[:, 2] = all_anchors[:, 0] + 0.5 * anchor_loc[:, 0]
anchor_loc[:, 3] = (all_anchors[:, 3] + all_anchors[:, 1]) / 2 anchor_loc[:, 3] = all_anchors[:, 1] + 0.5 * anchor_loc[:, 1]
#predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height #predicted bbox: bbox_center_x, bbox_center_y, bbox_width, bbox_height
pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32) pred_bbox = np.zeros_like(bbox_deltas, dtype=np.float32)
@ -127,23 +127,29 @@ def box_coder(all_anchors, bbox_deltas, variances):
i, 0] + anchor_loc[i, 2] i, 0] + anchor_loc[i, 2]
pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[ pred_bbox[i, 1] = variances[i, 1] * bbox_deltas[i, 1] * anchor_loc[
i, 1] + anchor_loc[i, 3] i, 1] + anchor_loc[i, 3]
pred_bbox[i, 2] = math.exp(variances[i, 2] * pred_bbox[i, 2] = math.exp(
bbox_deltas[i, 2]) * anchor_loc[i, 0] min(variances[i, 2] * bbox_deltas[i, 2], math.log(
pred_bbox[i, 3] = math.exp(variances[i, 3] * 1000 / 16.0))) * anchor_loc[i, 0]
bbox_deltas[i, 3]) * anchor_loc[i, 1] pred_bbox[i, 3] = math.exp(
min(variances[i, 3] * bbox_deltas[i, 3], math.log(
1000 / 16.0))) * anchor_loc[i, 1]
else: else:
for i in range(bbox_deltas.shape[0]): for i in range(bbox_deltas.shape[0]):
pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[ pred_bbox[i, 0] = bbox_deltas[i, 0] * anchor_loc[i, 0] + anchor_loc[
i, 2] i, 2]
pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[ pred_bbox[i, 1] = bbox_deltas[i, 1] * anchor_loc[i, 1] + anchor_loc[
i, 3] i, 3]
pred_bbox[i, 2] = math.exp(bbox_deltas[i, 2]) * anchor_loc[i, 0] pred_bbox[i, 2] = math.exp(
pred_bbox[i, 3] = math.exp(bbox_deltas[i, 3]) * anchor_loc[i, 1] min(bbox_deltas[i, 2], math.log(1000 / 16.0))) * anchor_loc[i,
0]
pred_bbox[i, 3] = math.exp(
min(bbox_deltas[i, 3], math.log(1000 / 16.0))) * anchor_loc[i,
1]
proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2 proposals[:, 0] = pred_bbox[:, 0] - pred_bbox[:, 2] / 2
proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2 proposals[:, 1] = pred_bbox[:, 1] - pred_bbox[:, 3] / 2
proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 proposals[:, 2] = pred_bbox[:, 0] + pred_bbox[:, 2] / 2 - 1
proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 proposals[:, 3] = pred_bbox[:, 1] + pred_bbox[:, 3] / 2 - 1
return proposals return proposals
@ -170,13 +176,16 @@ def filter_boxes(boxes, min_size, im_info):
"""Only keep boxes with both sides >= min_size and center within the image. """Only keep boxes with both sides >= min_size and center within the image.
""" """
# Scale min_size to match image scale # Scale min_size to match image scale
min_size *= im_info[2] im_scale = im_info[2]
min_size = max(min_size, 1.0)
ws = boxes[:, 2] - boxes[:, 0] + 1 ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1 hs = boxes[:, 3] - boxes[:, 1] + 1
ws_orig_scale = (boxes[:, 2] - boxes[:, 0]) / im_scale + 1
hs_orig_scale = (boxes[:, 3] - boxes[:, 1]) / im_scale + 1
x_ctr = boxes[:, 0] + ws / 2. x_ctr = boxes[:, 0] + ws / 2.
y_ctr = boxes[:, 1] + hs / 2. y_ctr = boxes[:, 1] + hs / 2.
keep = np.where((ws >= min_size) & (hs >= min_size) & (x_ctr < im_info[1]) & keep = np.where((ws_orig_scale >= min_size) & (hs_orig_scale >= min_size) &
(y_ctr < im_info[0]))[0] (x_ctr < im_info[1]) & (y_ctr < im_info[0]))[0]
return keep return keep
@ -204,7 +213,7 @@ def iou(box_a, box_b):
xb = min(xmax_a, xmax_b) xb = min(xmax_a, xmax_b)
yb = min(ymax_a, ymax_b) yb = min(ymax_a, ymax_b)
inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) inter_area = max(xb - xa + 1, 0.0) * max(yb - ya + 1, 0.0)
iou_ratio = inter_area / (area_a + area_b - inter_area) iou_ratio = inter_area / (area_a + area_b - inter_area)
Loading…
Cancel
Save