|
|
|
@ -25,45 +25,55 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
|
|
|
|
|
"Input(DetectRes) of DetectionMAPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"),
|
|
|
|
|
"Input(Label) of DetectionMAPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("AccumPosCount"),
|
|
|
|
|
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("AccumTruePos"),
|
|
|
|
|
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasOutput("AccumFalsePos"),
|
|
|
|
|
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
|
|
|
|
|
"Output(MAP) of DetectionMAPOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("DetectRes"), "Input", "DetectRes",
|
|
|
|
|
"DetectionMAP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Label"), "Input", "Label", "DetectionMAP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("AccumPosCount"), "Output", "AccumPosCount",
|
|
|
|
|
"DetectionMAP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("AccumTruePos"), "Output", "AccumTruePos",
|
|
|
|
|
"DetectionMAP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("AccumFalsePos"), "Output", "AccumFalsePos",
|
|
|
|
|
"DetectionMAP");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("MAP"), "Output", "MAP", "DetectionMAP");
|
|
|
|
|
|
|
|
|
|
auto det_dims = ctx->GetInputDim("DetectRes");
|
|
|
|
|
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
|
|
|
|
|
"The rank of Input(DetectRes) must be 2, "
|
|
|
|
|
"the shape is [N, 6].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
|
|
|
|
|
"The shape is of Input(DetectRes) [N, 6].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
det_dims.size(), 2UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(DetectRes) ndim must be 2, the shape is [N, 6],"
|
|
|
|
|
"but received the ndim is %d",
|
|
|
|
|
det_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
det_dims[1], 6UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape is of Input(DetectRes) [N, 6], but received"
|
|
|
|
|
" shape is [N, %d]",
|
|
|
|
|
det_dims[1]));
|
|
|
|
|
auto label_dims = ctx->GetInputDim("Label");
|
|
|
|
|
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
|
|
|
|
|
"The rank of Input(Label) must be 2, "
|
|
|
|
|
"the shape is [N, 6].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The ndim of Input(Label) must be 2, but received %d",
|
|
|
|
|
label_dims.size()));
|
|
|
|
|
if (ctx->IsRuntime() || label_dims[1] > 0) {
|
|
|
|
|
PADDLE_ENFORCE(label_dims[1] == 6 || label_dims[1] == 5,
|
|
|
|
|
"The shape of Input(Label) is [N, 6] or [N, 5].");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
(label_dims[1] == 6 || label_dims[1] == 5), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The shape of Input(Label) is [N, 6] or [N, 5], but received "
|
|
|
|
|
"[N, %d]",
|
|
|
|
|
label_dims[1]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ctx->HasInput("PosCount")) {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("TruePos"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(TruePos) of DetectionMAPOp should not be null when "
|
|
|
|
|
"Input(TruePos) is not null.");
|
|
|
|
|
"Input(PosCount) is not null."));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("FalsePos"),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(FalsePos) of DetectionMAPOp should not be null when "
|
|
|
|
|
"Input(FalsePos) is not null.");
|
|
|
|
|
"Input(PosCount) is not null."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
|
|
|
|
@ -170,8 +180,10 @@ class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.SetDefault("integral")
|
|
|
|
|
.InEnum({"integral", "11point"})
|
|
|
|
|
.AddCustomChecker([](const std::string& ap_type) {
|
|
|
|
|
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
|
|
|
|
|
"The ap_type should be 'integral' or '11point.");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
GetAPType(ap_type), APType::kNone,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The ap_type should be 'integral' or '11point."));
|
|
|
|
|
});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Detection mAP evaluate operator.
|
|
|
|
|