refine code

release/1.1
jerrywgz 7 years ago
parent 1c591c3909
commit e0708e62ba

@ -53,8 +53,8 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->HasOutput("TargetBBox"),
"Output(TargetBBox) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BBox_inside_weight"),
"Output(BBox_inside_weight) of RpnTargetAssignOp should not be null");
ctx->HasOutput("BBoxInsideWeight"),
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
@ -71,7 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4});
ctx->SetOutputDim("BBox_inside_weight", {-1, 4});
ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
}
protected:
@ -345,7 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto* score_index = context.Output<LoDTensor>("ScoreIndex");
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBox_inside_weight");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD");
@ -547,7 +547,7 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number.");
AddOutput("BBox_inside_weight",
AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number.");
AddComment(R"DOC(

@ -167,7 +167,7 @@ def rpn_target_assign(bbox_pred,
'ScoreIndex': score_index,
'TargetLabel': target_label,
'TargetBBox': target_bbox,
'BBox_inside_weight': bbox_inside_weight
'BBoxInsideWeight': bbox_inside_weight
},
attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im,

@ -324,6 +324,7 @@ class TestRpnTargetAssign(unittest.TestCase):
assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1]
print(str(program))
class TestGenerateProposals(unittest.TestCase):

@ -227,7 +227,7 @@ class TestRpnTargetAssignOp(OpTest):
'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': labels.astype('int32'),
'BBox_inside_weight': bbox_inside_weights.astype('float32')
'BBoxInsideWeight': bbox_inside_weights.astype('float32')
}
def test_check_output(self):

Loading…
Cancel
Save