|
|
|
@ -626,78 +626,34 @@ real PrecisionRecallEvaluator::evalImp(std::vector<Argument>& arguments) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T1, typename T2>
|
|
|
|
|
void PrecisionRecallEvaluator::printStatsHelper(T1 labelCallback,
|
|
|
|
|
T2 microAvgCallback) const {
|
|
|
|
|
int label = config_.positive_label();
|
|
|
|
|
if (label != -1) {
|
|
|
|
|
CHECK(label >= 0 && label < (int)statsInfo_.size())
|
|
|
|
|
<< "positive_label [" << label << "] should be in range [0, "
|
|
|
|
|
<< statsInfo_.size() << ")";
|
|
|
|
|
double precision =
|
|
|
|
|
calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
|
|
|
|
|
double recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
|
|
|
|
|
labelCallback(label, precision, recall, calcF1Score(precision, recall));
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
|
|
|
|
|
// macro average method: precision = (precision1+precision2)/2
|
|
|
|
|
double microTotalTP = 0;
|
|
|
|
|
double microTotalFP = 0;
|
|
|
|
|
double microTotalFN = 0;
|
|
|
|
|
double macroAvgPrecision = 0;
|
|
|
|
|
double macroAvgRecall = 0;
|
|
|
|
|
size_t numLabels = statsInfo_.size();
|
|
|
|
|
for (size_t i = 0; i < numLabels; ++i) {
|
|
|
|
|
microTotalTP += statsInfo_[i].TP;
|
|
|
|
|
microTotalFP += statsInfo_[i].FP;
|
|
|
|
|
microTotalFN += statsInfo_[i].FN;
|
|
|
|
|
macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
|
|
|
|
|
macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
|
|
|
|
|
}
|
|
|
|
|
macroAvgPrecision /= numLabels;
|
|
|
|
|
macroAvgRecall /= numLabels;
|
|
|
|
|
double macroAvgF1Score = calcF1Score(macroAvgPrecision, macroAvgRecall);
|
|
|
|
|
|
|
|
|
|
double microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
|
|
|
|
|
double microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
|
|
|
|
|
double microAvgF1Score = calcF1Score(microAvgPrecision, microAvgRecall);
|
|
|
|
|
|
|
|
|
|
microAvgCallback(macroAvgPrecision,
|
|
|
|
|
macroAvgRecall,
|
|
|
|
|
macroAvgF1Score,
|
|
|
|
|
isMultiBinaryLabel_,
|
|
|
|
|
microAvgPrecision,
|
|
|
|
|
microAvgRecall,
|
|
|
|
|
microAvgF1Score);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::printStats(std::ostream& os) const {
|
|
|
|
|
this->printStatsHelper(
|
|
|
|
|
[&os](int label, double precision, double recall, double f1) {
|
|
|
|
|
os << "positive_label=" << label << " precision=" << precision
|
|
|
|
|
<< " recall=" << recall << " F1-score=" << f1;
|
|
|
|
|
},
|
|
|
|
|
[&os](double macroAvgPrecision,
|
|
|
|
|
double macroAvgRecall,
|
|
|
|
|
double macroAvgF1Score,
|
|
|
|
|
bool isMultiBinaryLabel,
|
|
|
|
|
double microAvgPrecision,
|
|
|
|
|
double microAvgRecall,
|
|
|
|
|
double microAvgF1Score) {
|
|
|
|
|
os << "macro-average-precision=" << macroAvgPrecision
|
|
|
|
|
<< " macro-average-recall=" << macroAvgRecall
|
|
|
|
|
<< " macro-average-F1-score=" << macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision
|
|
|
|
|
<< " micro-average-recall=" << microAvgRecall
|
|
|
|
|
<< " micro-average-F1-score=" << microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
|
|
|
|
|
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&precision,
|
|
|
|
|
&recall,
|
|
|
|
|
&f1,
|
|
|
|
|
¯oAvgPrecision,
|
|
|
|
|
¯oAvgRecall,
|
|
|
|
|
¯oAvgF1Score,
|
|
|
|
|
µAvgPrecision,
|
|
|
|
|
µAvgRecall,
|
|
|
|
|
µAvgF1Score);
|
|
|
|
|
os << "positive_label=" << config_.positive_label()
|
|
|
|
|
<< " precision=" << precision << " recall=" << recall
|
|
|
|
|
<< " F1-score=" << f1;
|
|
|
|
|
if (containMacroMicroInfo) {
|
|
|
|
|
os << "macro-average-precision=" << macroAvgPrecision
|
|
|
|
|
<< " macro-average-recall=" << macroAvgRecall
|
|
|
|
|
<< " macro-average-F1-score=" << macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel_) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
os << " micro-average-precision=" << microAvgPrecision
|
|
|
|
|
<< " micro-average-recall=" << microAvgRecall
|
|
|
|
|
<< " micro-average-F1-score=" << microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::calcStatsInfo(const MatrixPtr& output,
|
|
|
|
@ -780,32 +736,33 @@ void PrecisionRecallEvaluator::calcStatsInfoMulti(const MatrixPtr& output,
|
|
|
|
|
|
|
|
|
|
void PrecisionRecallEvaluator::storeLocalValues() const {
|
|
|
|
|
if (this->values_.size() == 0) {
|
|
|
|
|
this->printStatsHelper(
|
|
|
|
|
[this](int label, double precision, double recall, double f1) {
|
|
|
|
|
values_["positive_label"] = (double)label;
|
|
|
|
|
values_["precision"] = precision;
|
|
|
|
|
values_["recal"] = recall;
|
|
|
|
|
values_["F1-score"] = f1;
|
|
|
|
|
},
|
|
|
|
|
[this](double macroAvgPrecision,
|
|
|
|
|
double macroAvgRecall,
|
|
|
|
|
double macroAvgF1Score,
|
|
|
|
|
bool isMultiBinaryLabel,
|
|
|
|
|
double microAvgPrecision,
|
|
|
|
|
double microAvgRecall,
|
|
|
|
|
double microAvgF1Score) {
|
|
|
|
|
values_["macro-average-precision"] = macroAvgPrecision;
|
|
|
|
|
values_["macro-average-recall"] = macroAvgRecall;
|
|
|
|
|
values_["macro-average-F1-score"] = macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
values_["micro-average-recall"] = microAvgRecall;
|
|
|
|
|
values_["micro-average-F1-score"] = microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
double precision, recall, f1, macroAvgPrecision, macroAvgRecall,
|
|
|
|
|
macroAvgF1Score, microAvgPrecision, microAvgRecall, microAvgF1Score;
|
|
|
|
|
bool containMacroMicroInfo = getStatsInfo(&precision,
|
|
|
|
|
&recall,
|
|
|
|
|
&f1,
|
|
|
|
|
¯oAvgPrecision,
|
|
|
|
|
¯oAvgRecall,
|
|
|
|
|
¯oAvgF1Score,
|
|
|
|
|
µAvgPrecision,
|
|
|
|
|
µAvgRecall,
|
|
|
|
|
µAvgF1Score);
|
|
|
|
|
values_["precision"] = precision;
|
|
|
|
|
values_["recal"] = recall;
|
|
|
|
|
values_["F1-score"] = f1;
|
|
|
|
|
if (containMacroMicroInfo) {
|
|
|
|
|
values_["macro-average-precision"] = macroAvgPrecision;
|
|
|
|
|
values_["macro-average-recall"] = macroAvgRecall;
|
|
|
|
|
values_["macro-average-F1-score"] = macroAvgF1Score;
|
|
|
|
|
if (!isMultiBinaryLabel_) {
|
|
|
|
|
// precision and recall are equal in this case
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
} else {
|
|
|
|
|
values_["micro-average-precision"] = microAvgPrecision;
|
|
|
|
|
values_["micro-average-recall"] = microAvgRecall;
|
|
|
|
|
values_["micro-average-F1-score"] = microAvgF1Score;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -865,6 +822,51 @@ void PrecisionRecallEvaluator::distributeEval(ParameterClient2* client) {
|
|
|
|
|
delete[] buf;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool PrecisionRecallEvaluator::getStatsInfo(double* precision,
|
|
|
|
|
double* recall,
|
|
|
|
|
double* f1,
|
|
|
|
|
double* macroAvgPrecision,
|
|
|
|
|
double* macroAvgRecall,
|
|
|
|
|
double* macroAvgF1Score,
|
|
|
|
|
double* microAvgPrecision,
|
|
|
|
|
double* microAvgRecall,
|
|
|
|
|
double* microAvgF1Score) const {
|
|
|
|
|
int label = config_.positive_label();
|
|
|
|
|
if (label != -1) {
|
|
|
|
|
CHECK(label >= 0 && label < (int)statsInfo_.size())
|
|
|
|
|
<< "positive_label [" << label << "] should be in range [0, "
|
|
|
|
|
<< statsInfo_.size() << ")";
|
|
|
|
|
*precision = calcPrecision(statsInfo_[label].TP, statsInfo_[label].FP);
|
|
|
|
|
*recall = calcRecall(statsInfo_[label].TP, statsInfo_[label].FN);
|
|
|
|
|
*f1 = calcF1Score(*precision, *recall);
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// micro average method: precision = (TP1+TP2)/(TP1+FP1+TP2+FP2)
|
|
|
|
|
// macro average method: precision = (precision1+precision2)/2
|
|
|
|
|
double microTotalTP = 0;
|
|
|
|
|
double microTotalFP = 0;
|
|
|
|
|
double microTotalFN = 0;
|
|
|
|
|
*macroAvgPrecision = 0;
|
|
|
|
|
*macroAvgRecall = 0;
|
|
|
|
|
size_t numLabels = statsInfo_.size();
|
|
|
|
|
for (size_t i = 0; i < numLabels; ++i) {
|
|
|
|
|
microTotalTP += statsInfo_[i].TP;
|
|
|
|
|
microTotalFP += statsInfo_[i].FP;
|
|
|
|
|
microTotalFN += statsInfo_[i].FN;
|
|
|
|
|
*macroAvgPrecision += calcPrecision(statsInfo_[i].TP, statsInfo_[i].FP);
|
|
|
|
|
*macroAvgRecall += calcRecall(statsInfo_[i].TP, statsInfo_[i].FN);
|
|
|
|
|
}
|
|
|
|
|
*macroAvgPrecision /= numLabels;
|
|
|
|
|
*macroAvgRecall /= numLabels;
|
|
|
|
|
*macroAvgF1Score = calcF1Score(*macroAvgPrecision, *macroAvgRecall);
|
|
|
|
|
|
|
|
|
|
*microAvgPrecision = calcPrecision(microTotalTP, microTotalFP);
|
|
|
|
|
*microAvgRecall = calcPrecision(microTotalTP, microTotalFN);
|
|
|
|
|
*microAvgF1Score = calcF1Score(*microAvgPrecision, *microAvgRecall);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_EVALUATOR(pnpair, PnpairEvaluator);
|
|
|
|
|
void PnpairEvaluator::start() {
|
|
|
|
|
Evaluator::start();
|
|
|
|
|