|
|
|
@ -647,33 +647,24 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
|
|
|
|
|
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
|
|
|
|
|
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&precision,
|
|
|
|
|
&recall,
|
|
|
|
|
&f1,
|
|
|
|
|
¯oAvgPrecision,
|
|
|
|
|
¯oAvgRecall,
|
|
|
|
|
¯oAvgF1Score,
|
|
|
|
|
µAvgPrecision,
|
|
|
|
|
µAvgRecall,
|
|
|
|
|
µAvgF1Score);
|
|
|
|
|
PrintStatsInfo info;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&info);
|
|
|
|
|
os << "positive_label=" << config_.positive_label()
|
|
|
|
|
<< " precision=" << precision << " recall=" << recall
|
|
|
|
|
<< " F1-score=" << f1;
|
|
|
|
|
<< " precision=" << info.precision << " recall=" << info.recall
|
|
|
|
|
<< " F1-score=" << info.f1;
|
|
|
|
|
if (containMacroMicroInfo) {
|
|
|
|
|
os << "macro-average-precision=" << macroAvgPrecision
|
|
|
|
|
<< " macro-average-recall=" << macroAvgRecall
|
|
|
|
|
<< " macro-average-F1-score=" << macroAvgF1Score;
|
|
|
|
|
os << "macro-average-precision=" << info.macroAvgPrecision
|
|
|
|
|
<< " macro-average-recall=" << info.macroAvgRecall
|
|
|
|
|
<< " macro-average-F1-score=" << info.macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel_) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision;
|
|
|
|
|
os << " micro-average-precision=" << info.microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision
|
|
|
|
|
<< " micro-average-recall=" << microAvgRecall
|
|
|
|
|
<< " micro-average-F1-score=" << microAvgF1Score;
|
|
|
|
|
os << " micro-average-precision=" << info.microAvgPrecision
|
|
|
|
|
<< " micro-average-recall=" << info.microAvgRecall
|
|
|
|
|
<< " micro-average-F1-score=" << info.microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
|
|
|
|
@ -756,31 +747,22 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::storeLocalValues() const {
|
|
|
|
|
if (this->values_.size() == 0) {
|
|
|
|
|
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
|
|
|
|
|
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&precision,
|
|
|
|
|
&recall,
|
|
|
|
|
&f1,
|
|
|
|
|
¯oAvgPrecision,
|
|
|
|
|
¯oAvgRecall,
|
|
|
|
|
¯oAvgF1Score,
|
|
|
|
|
µAvgPrecision,
|
|
|
|
|
µAvgRecall,
|
|
|
|
|
µAvgF1Score);
|
|
|
|
|
values_["precision"] = precision;
|
|
|
|
|
values_["recal"] = recall;
|
|
|
|
|
values_["F1-score"] = f1;
|
|
|
|
|
PrintStatsInfo info;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&info);
|
|
|
|
|
values_["precision"] = info.precision;
|
|
|
|
|
values_["recal"] = info.recall;
|
|
|
|
|
values_["F1-score"] = info.f1;
|
|
|
|
|
if (containMacroMicroInfo) {
|
|
|
|
|
values_["macro-average-precision"] = macroAvgPrecision;
|
|
|
|
|
values_["macro-average-recall"] = macroAvgRecall;
|
|
|
|
|
values_["macro-average-F1-score"] = macroAvgF1Score;
|
|
|
|
|
values_["macro-average-precision"] = info.macroAvgPrecision;
|
|
|
|
|
values_["macro-average-recall"] = info.macroAvgRecall;
|
|
|
|
|
values_["macro-average-F1-score"] = info.macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel_) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
values_["micro-average-precision"] = info.microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
values_["micro-average-recall"] = microAvgRecall;
|
|
|
|
|
values_["micro-average-F1-score"] = microAvgF1Score;
|
|
|
|
|
values_["micro-average-precision"] = info.microAvgPrecision;
|
|
|
|
|
values_["micro-average-recall"] = info.microAvgRecall;
|
|
|
|
|
values_["micro-average-F1-score"] = info.microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -836,23 +818,16 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
|
|
|
|
|
delete[] buf;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
|
|
|
|
|
double* recall,
|
|
|
|
|
double* f1,
|
|
|
|
|
double* macroAvgPrecision,
|
|
|
|
|
double* macroAvgRecall,
|
|
|
|
|
double* macroAvgF1Score,
|
|
|
|
|
double* microAvgPrecision,
|
|
|
|
|
double* microAvgRecall,
|
|
|
|
|
double* microAvgF1Score) const {
|
|
|
|
|
bool PrecisionRecallEvaluator::getStatsInfo(
|
|
|
|
|
PrecisionRecallEvaluator::PrintStatsInfo* info) const {
|
|
|
|
|
int label = config_.positive_label();
|
|
|
|
|
if (label != -1) {
|
|
|
|
|
CHECK(label >= 0 && label < (int)statsInfo_.size())
|
|
|
|
|
<< "positive_label [" << label << "] should be in range [0, "
|
|
|
|
|
<< statsInfo_.size() << ")";
|
|
|
|
|
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
|
|
|
|
|
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
|
|
|
|
|
*f1 = calcF1Score(*precision, *recall);
|
|
|
|
|
info->precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
|
|
|
|
|
info->recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
|
|
|
|
|
info->f1 = calcF1Score(info->precision, info->recall);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -861,23 +836,26 @@ bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
|
|
|
|
|
double microTotalTP = 0;
|
|
|
|
|
double microTotalFP = 0;
|
|
|
|
|
double microTotalFN = 0;
|
|
|
|
|
*macroAvgPrecision = 0;
|
|
|
|
|
*macroAvgRecall = 0;
|
|
|
|
|
info->macroAvgPrecision = 0;
|
|
|
|
|
info->macroAvgRecall = 0;
|
|
|
|
|
size_t numLabels = statsInfo_.size();
|
|
|
|
|
for (size_t i = 0; i < numLabels; ++i) {
|
|
|
|
|
microTotalTP += statsInfo_[i].TP;
|
|
|
|
|
microTotalFP += statsInfo_[i].FP;
|
|
|
|
|
microTotalFN += statsInfo_[i].FN;
|
|
|
|
|
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
|
|
|
|
|
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
|
|
|
|
|
info->macroAvgPrecision +=
|
|
|
|
|
calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
|
|
|
|
|
info->macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
|
|
|
|
|
}
|
|
|
|
|
*macroAvgPrecision /= numLabels;
|
|
|
|
|
*macroAvgRecall /= numLabels;
|
|
|
|
|
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
|
|
|
|
|
|
|
|
|
|
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
|
|
|
|
|
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
|
|
|
|
|
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
|
|
|
|
|
info->macroAvgPrecision /= numLabels;
|
|
|
|
|
info->macroAvgRecall /= numLabels;
|
|
|
|
|
info->macroAvgF1Score =
|
|
|
|
|
calcF1Score(info->macroAvgPrecision, info->macroAvgRecall);
|
|
|
|
|
|
|
|
|
|
info->microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
|
|
|
|
|
info->microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
|
|
|
|
|
info->microAvgF1Score =
|
|
|
|
|
calcF1Score(info->microAvgPrecision, info->microAvgRecall);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|