diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 2976c3ff4c..7df03990b7 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -41,6 +41,11 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { } ctx->SetOutputsDim("MultiFpnRois", outs_dims); ctx->SetOutputDim("RestoreIndex", {-1, 1}); + if (!ctx->IsRuntime()) { + for (size_t i = 0; i < num_out_rois; ++i) { + ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i); + } + } } protected: diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 9cdc46b4a2..0cfb79b358 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -74,6 +74,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { } else { ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } + if (!ctx->IsRuntime()) { + ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1)); + } } protected: @@ -493,6 +496,9 @@ class MultiClassNMS2Op : public MultiClassNMSOp { } else { ctx->SetOutputDim("Index", {-1, 1}); } + if (!ctx->IsRuntime()) { + ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1)); + } } }; diff --git a/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py b/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py index 39db9f5476..ee8202aa9f 100644 --- a/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py @@ -30,11 +30,8 @@ COMPILE_RUN_OP_WHITE_LIST = [ 'rpn_target_assign', \ 'retinanet_target_assign', \ 'filter_by_instag', \ - 'multiclass_nms', \ - 'multiclass_nms2', \ 'im2sequence', \ 'generate_proposal_labels', \ - 'distribute_fpn_proposals', \ 'detection_map', \ 'locality_aware_nms', \ 'var_conv_2d'