fix lod level, test=develop (#22755)

revert-22710-feature/integrated_ps_api
wangguanzhong 5 years ago committed by GitHub
parent 79d712346f
commit f2d1cd119a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,6 +41,11 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputsDim("MultiFpnRois", outs_dims); ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {-1, 1}); ctx->SetOutputDim("RestoreIndex", {-1, 1});
if (!ctx->IsRuntime()) {
for (size_t i = 0; i < num_out_rois; ++i) {
ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i);
}
}
} }
protected: protected:

@ -74,6 +74,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
} else { } else {
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
} }
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
}
} }
protected: protected:
@ -493,6 +496,9 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
} else { } else {
ctx->SetOutputDim("Index", {-1, 1}); ctx->SetOutputDim("Index", {-1, 1});
} }
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
}
} }
}; };

@ -30,11 +30,8 @@ COMPILE_RUN_OP_WHITE_LIST = [
'rpn_target_assign', \ 'rpn_target_assign', \
'retinanet_target_assign', \ 'retinanet_target_assign', \
'filter_by_instag', \ 'filter_by_instag', \
'multiclass_nms', \
'multiclass_nms2', \
'im2sequence', \ 'im2sequence', \
'generate_proposal_labels', \ 'generate_proposal_labels', \
'distribute_fpn_proposals', \
'detection_map', \ 'detection_map', \
'locality_aware_nms', \ 'locality_aware_nms', \
'var_conv_2d' 'var_conv_2d'

Loading…
Cancel
Save