|
|
@ -75,6 +75,11 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel {
|
|
|
|
ctx->SetOutputDim("MaskRois", {-1, 4});
|
|
|
|
ctx->SetOutputDim("MaskRois", {-1, 4});
|
|
|
|
ctx->SetOutputDim("RoiHasMaskInt32", {-1, 1});
|
|
|
|
ctx->SetOutputDim("RoiHasMaskInt32", {-1, 1});
|
|
|
|
ctx->SetOutputDim("MaskInt32", {-1, num_classes * resolution * resolution});
|
|
|
|
ctx->SetOutputDim("MaskInt32", {-1, num_classes * resolution * resolution});
|
|
|
|
|
|
|
|
if (!ctx->IsRuntime()) {
|
|
|
|
|
|
|
|
ctx->SetLoDLevel("MaskRois", ctx->GetLoDLevel("Rois"));
|
|
|
|
|
|
|
|
ctx->SetLoDLevel("RoiHasMaskInt32", ctx->GetLoDLevel("Rois"));
|
|
|
|
|
|
|
|
ctx->SetLoDLevel("MaskInt32", ctx->GetLoDLevel("Rois"));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|