diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
index dda423efd3..63895f8a1d 100644
--- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc
+++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(
         ctx->HasOutput("TargetBBox"),
         "Output(TargetBBox) of RpnTargetAssignOp should not be null");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("BBox_inside_weight"),
+        "Output(BBox_inside_weight) of RpnTargetAssignOp should not be null");
 
     auto anchor_dims = ctx->GetInputDim("Anchor");
     auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ScoreIndex", {-1});
     ctx->SetOutputDim("TargetLabel", {-1, 1});
     ctx->SetOutputDim("TargetBBox", {-1, 4});
+    ctx->SetOutputDim("BBox_inside_weight", {-1, 4});
   }
 
  protected:
@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
                  const float rpn_positive_overlap,
                  const float rpn_negative_overlap, std::vector<int>* fg_inds,
                  std::vector<int>* bg_inds, std::vector<int>* tgt_lbl,
+                 std::vector<int>* fg_fake, std::vector<T>* bbox_inside_weight,
                  std::minstd_rand engine, bool use_random) {
   float epsilon = 0.00001;
   int anchor_num = anchor_to_gt_max.dims()[0];
@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
   // Reservoir Sampling
   int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
   ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
-  fg_num = static_cast<int>(fg_inds_fake.size());
-  for (int64_t i = 0; i < fg_num; ++i) {
+  int fg_fake_num = static_cast<int>(fg_inds_fake.size());
+  for (int64_t i = 0; i < fg_fake_num; ++i) {
     target_label[fg_inds_fake[i]] = 1;
   }
 
-  int bg_num = rpn_batch_size_per_im - fg_num;
+  int bg_num = rpn_batch_size_per_im - fg_fake_num;
   for (int64_t i = 0; i < anchor_num; ++i) {
     if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
       bg_inds_fake.push_back(i);
@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
   }
   ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
   bg_num = static_cast<int>(bg_inds_fake.size());
+  int fake_num = 0;
   for (int64_t i = 0; i < bg_num; ++i) {
+    // fg fake found
+    if (target_label[bg_inds_fake[i]] == 1) {
+      fake_num++;
+      fg_fake->emplace_back(fg_inds_fake[0]);
+      for (int j = 0; j < 4; ++j) {
+        bbox_inside_weight->emplace_back(T(0.));
+      }
+    }
     target_label[bg_inds_fake[i]] = 0;
   }
 
+  for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) {
+    bbox_inside_weight->emplace_back(T(1.));
+  }
+
   for (int64_t i = 0; i < anchor_num; ++i) {
-    if (target_label[i] == 1) fg_inds->emplace_back(i);
+    if (target_label[i] == 1) {
+      fg_inds->emplace_back(i);
+      fg_fake->emplace_back(i);
+    }
     if (target_label[i] == 0) bg_inds->emplace_back(i);
   }
   fg_num = fg_inds->size();
@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
   std::vector<int> bg_inds;
   std::vector<int> gt_inds;
   std::vector<int> tgt_lbl;
-
+  std::vector<int> fg_fake;
+  std::vector<T> bbox_inside_weight;
   // Calculate the max IoU between anchors and gt boxes
   // Map from anchor to gt box that has highest overlap
   auto place = ctx.GetPlace();
@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
   // Follow the Faster RCNN's implementation
   ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max,
               rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap,
-              rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine,
-              use_random);
+              rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake,
+              &bbox_inside_weight, engine, use_random);
 
   int fg_num = fg_inds.size();
   int bg_num = bg_inds.size();
-  gt_inds.reserve(fg_num);
-  for (int i = 0; i < fg_num; ++i) {
-    gt_inds.emplace_back(argmax[fg_inds[i]]);
+  int fg_fake_num = fg_fake.size();
+  gt_inds.reserve(fg_fake_num);
+  for (int i = 0; i < fg_fake_num; ++i) {
+    gt_inds.emplace_back(argmax[fg_fake[i]]);
   }
-
-  Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t;
-  int* loc_index_data = loc_index_t.mutable_data<int>({fg_num}, place);
+  Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
+  int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
   int* score_index_data =
       score_index_t.mutable_data<int>({fg_num + bg_num}, place);
   int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
-  int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_num}, place);
-  std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data);
+  int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
+  T* bbox_inside_weight_data =
+      bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
+  std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
   std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
   std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
   std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
   std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
+  std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
+            bbox_inside_weight_data);
   std::vector<Tensor> loc_score_tgtlbl_gt;
   loc_score_tgtlbl_gt.emplace_back(loc_index_t);
   loc_score_tgtlbl_gt.emplace_back(score_index_t);
   loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
   loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
+  loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
 
   return loc_score_tgtlbl_gt;
 }
@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
     auto* score_index = context.Output<LoDTensor>("ScoreIndex");
     auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
     auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
+    auto* bbox_inside_weight = context.Output<LoDTensor>("BBox_inside_weight");
 
     PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
                       "RpnTargetAssignOp gt_boxes needs 1 level of LoD");
@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
     score_index->mutable_data<int>({max_num}, place);
     tgt_bbox->mutable_data<T>({max_num, 4}, place);
     tgt_lbl->mutable_data<int>({max_num, 1}, place);
-
+    bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
     auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
 
     std::random_device rnd;
@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
       Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
       Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
       Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
+      Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
 
       int loc_num = sampled_loc_index.dims()[0];
       int score_num = sampled_score_index.dims()[0];
@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
       AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
       AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
       AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
+      AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
+                    &sampled_bbox_inside_weight);
       total_loc_num += loc_num;
 
       total_score_num += score_num;
@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
     score_index->set_lod(loc_score);
     tgt_bbox->set_lod(lod_loc);
     tgt_lbl->set_lod(loc_score);
+    bbox_inside_weight->set_lod(lod_loc);
     loc_index->Resize({total_loc_num});
     score_index->Resize({total_score_num});
     tgt_bbox->Resize({total_loc_num, 4});
     tgt_lbl->Resize({total_score_num, 1});
+    bbox_inside_weight->Resize({total_loc_num, 4});
   }
 };
 
@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
         "TargetLabel",
         "(Tensor<int>), The target labels of each anchor with shape "
         "[F + B, 1], F and B are sampled foreground and backgroud number.");
+    AddOutput("BBox_inside_weight",
+              "(Tensor), The bbox inside weight with shape "
+              "[F, 4], F is the sampled foreground number.");
     AddComment(R"DOC(
 This operator can be, for a given set of ground truth bboxes and the
 anchors, to assign classification and regression targets to each prediction.
diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py
index 1cfcbbb9c1..8026fa9398 100644
--- a/python/paddle/fluid/layers/detection.py
+++ b/python/paddle/fluid/layers/detection.py
@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred,
     Returns:
         tuple:
                A tuple(predicted_scores, predicted_location, target_label,
-               target_bbox) is returned. The predicted_scores and
-               predicted_location is the predicted result of the RPN.
+               target_bbox, bbox_inside_weight) is returned. The predicted_scores 
+               and predicted_location is the predicted result of the RPN.
                The target_label and target_bbox is the ground truth,
                respectively. The predicted_location is a 2D Tensor with shape
                [F, 4], and the shape of target_bbox is same as the shape of
@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred,
                [F + B, 1], and the shape of target_label is same as the shape
                of the predicted_scores, B is the number of the background
                anchors, the F and B is depends on the input of this operator.
+               Bbox_inside_weight represents whether the predicted loc is fake_fg
+               or not and the shape is [F, 4].
 
     Examples:
         .. code-block:: python
@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred,
                           append_batch_size=False, dtype='float32')
         gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
                          append_batch_size=False, dtype='float32')
-        loc_pred, score_pred, loc_target, score_target =
+        loc_pred, score_pred, loc_target, score_target, bbox_inside_weight =
             fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
                                           cls_logits=cls_logits,
                                           anchor_box=anchor_box,
@@ -151,6 +153,7 @@ def rpn_target_assign(bbox_pred,
     score_index = helper.create_tmp_variable(dtype='int32')
     target_label = helper.create_tmp_variable(dtype='int32')
     target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype)
+    bbox_inside_weight = helper.create_tmp_variable(dtype=anchor_box.dtype)
     helper.append_op(
         type="rpn_target_assign",
         inputs={
@@ -163,7 +166,8 @@ def rpn_target_assign(bbox_pred,
             'LocationIndex': loc_index,
             'ScoreIndex': score_index,
             'TargetLabel': target_label,
-            'TargetBBox': target_bbox
+            'TargetBBox': target_bbox,
+            'BBox_inside_weight': bbox_inside_weight
         },
         attrs={
             'rpn_batch_size_per_im': rpn_batch_size_per_im,
@@ -178,13 +182,14 @@ def rpn_target_assign(bbox_pred,
     score_index.stop_gradient = True
     target_label.stop_gradient = True
     target_bbox.stop_gradient = True
+    bbox_inside_weight.stop_gradient = True
 
     cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
     bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
     predicted_cls_logits = nn.gather(cls_logits, score_index)
     predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
 
-    return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
+    return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
 
 
 def detection_output(loc,
diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py
index 56129641ce..b36b4272c7 100644
--- a/python/paddle/fluid/tests/test_detection.py
+++ b/python/paddle/fluid/tests/test_detection.py
@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase):
                 dtype='float32',
                 lod_level=1,
                 append_batch_size=False)
-            pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
+            pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign(
                 bbox_pred=bbox_pred,
                 cls_logits=cls_logits,
                 anchor_box=anchor_box,
@@ -313,12 +313,14 @@ class TestRpnTargetAssign(unittest.TestCase):
                 rpn_straddle_thresh=0.0,
                 rpn_fg_fraction=0.5,
                 rpn_positive_overlap=0.7,
-                rpn_negative_overlap=0.3)
+                rpn_negative_overlap=0.3,
+                use_random=False)
 
             self.assertIsNotNone(pred_scores)
             self.assertIsNotNone(pred_loc)
             self.assertIsNotNone(tgt_lbl)
             self.assertIsNotNone(tgt_bbox)
+            self.assertIsNotNone(bbox_inside_weight)
             assert pred_scores.shape[1] == 1
             assert pred_loc.shape[1] == 4
             assert pred_loc.shape[1] == tgt_bbox.shape[1]
diff --git a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
index f63dbcd3d7..fe1fa5e54d 100644
--- a/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
+++ b/python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py
@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap,
             fg_inds, size=(len(fg_inds) - num_fg), replace=False)
     else:
         disable_inds = fg_inds[num_fg:]
+
     labels[disable_inds] = -1
     fg_inds = np.where(labels == 1)[0]
+    bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
 
     num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
     bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap,
         enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
     else:
         enable_inds = bg_inds[:num_bg]
+
+    fg_fake_inds = np.array([], np.int32)
+    fg_value = np.array([fg_inds[0]], np.int32)
+    fake_num = 0
+    for bg_id in enable_inds:
+        if bg_id in fg_inds:
+            fake_num += 1
+            fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
     labels[enable_inds] = 0
+
+    bbox_inside_weight[fake_num:, :] = 1
     fg_inds = np.where(labels == 1)[0]
     bg_inds = np.where(labels == 0)[0]
-
-    loc_index = fg_inds
-    score_index = np.hstack((fg_inds, bg_inds))
+    loc_index = np.hstack([fg_fake_inds, fg_inds])
+    score_index = np.hstack([fg_inds, bg_inds])
     labels = labels[score_index]
     assert not np.any(labels == -1), "Wrong labels with -1"
 
-    gt_inds = anchor_to_gt_argmax[fg_inds]
+    gt_inds = anchor_to_gt_argmax[loc_index]
 
-    return loc_index, score_index, labels, gt_inds
+    return loc_index, score_index, labels, gt_inds, bbox_inside_weight
 
 
 def get_anchor(n, c, h, w):
@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors,
         gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
         iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
 
-        loc_inds, score_inds, labels, gt_inds = rpn_target_assign(
-            iou, rpn_batch_size_per_im, rpn_positive_overlap,
-            rpn_negative_overlap, rpn_fg_fraction, use_random)
+        loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \
+                         rpn_target_assign(iou, rpn_batch_size_per_im,
+                                           rpn_positive_overlap,
+                                           rpn_negative_overlap,
+                                           rpn_fg_fraction,
+                                           use_random)
         # unmap to all anchor 
         loc_inds = inds_inside[loc_inds]
         score_inds = inds_inside[score_inds]
@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors,
             score_indexes = score_inds
             tgt_labels = labels
             tgt_bboxes = box_deltas
+            bbox_inside_weights = bbox_inside_weight
         else:
             loc_indexes = np.concatenate(
                 [loc_indexes, loc_inds + i * anchor_num])
@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors,
                 [score_indexes, score_inds + i * anchor_num])
             tgt_labels = np.concatenate([tgt_labels, labels])
             tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
+            bbox_inside_weights = np.vstack([bbox_inside_weights, \
+                                             bbox_inside_weight])
 
-    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels
+    return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
 
 
 class TestRpnTargetAssignOp(OpTest):
@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest):
         rpn_fg_fraction = 0.5
         use_random = False
 
-        loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python(
-            all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh,
-            rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap,
-            rpn_fg_fraction, use_random)
+        loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \
+            rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd,
+                                   im_info, lod, rpn_straddle_thresh,
+                                   rpn_batch_size_per_im, rpn_positive_overlap,
+                                   rpn_negative_overlap,
+                                   rpn_fg_fraction, use_random)
         labels = labels[:, np.newaxis]
 
         self.op_type = "rpn_target_assign"
@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest):
             'LocationIndex': loc_index.astype('int32'),
             'ScoreIndex': score_index.astype('int32'),
             'TargetBBox': tgt_bbox.astype('float32'),
-            'TargetLabel': labels.astype('int32')
+            'TargetLabel': labels.astype('int32'),
+            'BBox_inside_weight': bbox_inside_weights.astype('float32')
         }
 
     def test_check_output(self):