|
|
|
@ -22,7 +22,6 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
// may contains weights and StatesInfo
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Predictions"),
|
|
|
|
|
"Input(Predictions) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Labels"),
|
|
|
|
@ -108,11 +107,54 @@ class PrecisionRecallOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"provided, current state will be accumulated to this state and "
|
|
|
|
|
"the accumulation state will be as the output state.")
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("BatchMetrics", "");
|
|
|
|
|
AddOutput("AccumMetrics", "");
|
|
|
|
|
AddOutput("AccumStatesInfo", "");
|
|
|
|
|
AddOutput("BatchMetrics",
|
|
|
|
|
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
|
|
|
|
|
"This output tensor contains metrics for current batch data."
|
|
|
|
|
"The layout is [macro average precision, macro average recall, "
|
|
|
|
|
"macro f1 score, micro average precision, micro average recall, "
|
|
|
|
|
"micro f1 score]");
|
|
|
|
|
AddOutput("AccumMetrics",
|
|
|
|
|
"(Tensor, default Tensor<float>), a 1-D tensor with shape {6}."
|
|
|
|
|
"This output tensor contains metrics for accumulated data."
|
|
|
|
|
"The layout is [macro average precision, macro average recall, "
|
|
|
|
|
"macro f1 score, micro average precision, micro average recall, "
|
|
|
|
|
"micro f1 score]");
|
|
|
|
|
AddOutput("AccumStatesInfo",
|
|
|
|
|
"(Tensor, default Tensor<float>), a 2-D tensor with shape D x 4, "
|
|
|
|
|
"where D is equal to class number. This output tensor contains "
|
|
|
|
|
"accumulated state variables used to compute metrics. The layout "
|
|
|
|
|
"for each class is [true positives, false positives, "
|
|
|
|
|
"true negatives, false negatives].");
|
|
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
When given 'Input(Predictions)' and 'Input(Labels)', this operator can be used
|
|
|
|
|
to compute various metrics including:
|
|
|
|
|
- macro average precision
|
|
|
|
|
- macro average recall
|
|
|
|
|
- macro f1 score
|
|
|
|
|
- micro average precision
|
|
|
|
|
- micro average recall
|
|
|
|
|
- micro f1 score
|
|
|
|
|
|
|
|
|
|
To compute the above metrics, we need to statistic counts for true positives,
|
|
|
|
|
false positives and false negatives. Here count of true negatives is not
|
|
|
|
|
necessary, but statisticing it may provide potential usage and the cost is
|
|
|
|
|
trivial, so the operator also provides count of true negatives.
|
|
|
|
|
|
|
|
|
|
We define state as a 2-D tensor with shape [class number, 4]. Each row of a
|
|
|
|
|
state contains statistic variables for corresponding class. Layout of each row
|
|
|
|
|
is: TP(true positives), FP(false positives), TN(true negatives),
|
|
|
|
|
FN(false negatives). If 'Input(Weights)' provided, TP, FP, TN, FN will be
|
|
|
|
|
calculated by given weight instead of instance count.
|
|
|
|
|
|
|
|
|
|
This operator also supports metrics computing for cross-batch situation. To
|
|
|
|
|
achieve this, 'Input(StatesInfo)' should be provided. State of current batch
|
|
|
|
|
data will be accumulated to 'Input(StatesInfo)' and 'Output(AccumStatesInfo)'
|
|
|
|
|
is the accumulation state.
|
|
|
|
|
|
|
|
|
|
'Output(BatchMetrics)' is metrics of current batch data while
|
|
|
|
|
'Output(AccumStatesInfo)' is metrics of accumulation data.
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|