|
|
|
@ -26,6 +26,7 @@ private:
|
|
|
|
|
int numTimes_, numClasses_, numSequences_, blank_;
|
|
|
|
|
real deletions_, insertions_, substitutions_;
|
|
|
|
|
int seqClassficationError_;
|
|
|
|
|
mutable std::unordered_map<std::string, real> evalResults_;
|
|
|
|
|
|
|
|
|
|
std::vector<int> path2String(const std::vector<int>& path) {
|
|
|
|
|
std::vector<int> str;
|
|
|
|
@ -183,6 +184,18 @@ private:
|
|
|
|
|
return stringAlignment(gtStr, recogStr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void storeLocalValues() const {
|
|
|
|
|
evalResults_["error"] = numSequences_ ? totalScore_ / numSequences_ : 0;
|
|
|
|
|
evalResults_["deletion_error"] =
|
|
|
|
|
numSequences_ ? deletions_ / numSequences_ : 0;
|
|
|
|
|
evalResults_["insertion_error"] =
|
|
|
|
|
numSequences_ ? insertions_ / numSequences_ : 0;
|
|
|
|
|
evalResults_["substitution_error"] =
|
|
|
|
|
numSequences_ ? substitutions_ / numSequences_ : 0;
|
|
|
|
|
evalResults_["sequence_error"] =
|
|
|
|
|
(real)seqClassficationError_ / numSequences_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
CTCErrorEvaluator()
|
|
|
|
|
: numTimes_(0),
|
|
|
|
@ -245,16 +258,12 @@ public:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void printStats(std::ostream& os) const {
|
|
|
|
|
os << config_.name() << "="
|
|
|
|
|
<< (numSequences_ ? totalScore_ / numSequences_ : 0);
|
|
|
|
|
os << " deletions error"
|
|
|
|
|
<< "=" << (numSequences_ ? deletions_ / numSequences_ : 0);
|
|
|
|
|
os << " insertions error"
|
|
|
|
|
<< "=" << (numSequences_ ? insertions_ / numSequences_ : 0);
|
|
|
|
|
os << " substitutions error"
|
|
|
|
|
<< "=" << (numSequences_ ? substitutions_ / numSequences_ : 0);
|
|
|
|
|
os << " sequences error"
|
|
|
|
|
<< "=" << (real)seqClassficationError_ / numSequences_;
|
|
|
|
|
storeLocalValues();
|
|
|
|
|
os << config_.name() << "=" << evalResults_["error"];
|
|
|
|
|
os << " deletions error = " << evalResults_["deletion_error"];
|
|
|
|
|
os << " insertions error = " << evalResults_["insertion_error"];
|
|
|
|
|
os << " substitution error = " << evalResults_["substitution_error"];
|
|
|
|
|
os << " sequence error = " << evalResults_["sequence_error"];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void distributeEval(ParameterClient2* client) {
|
|
|
|
@ -272,6 +281,33 @@ public:
|
|
|
|
|
seqClassficationError_ = (int)buf[4];
|
|
|
|
|
numSequences_ = (int)buf[5];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void getNames(std::vector<std::string>* names) {
|
|
|
|
|
storeLocalValues();
|
|
|
|
|
names->reserve(names->size() + evalResults_.size());
|
|
|
|
|
for (auto it = evalResults_.begin(); it != evalResults_.end(); ++it) {
|
|
|
|
|
names->push_back(config_.name() + "." + it->first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
real getValue(const std::string& name, Error* err) const {
|
|
|
|
|
storeLocalValues();
|
|
|
|
|
|
|
|
|
|
const std::string delimiter(".");
|
|
|
|
|
std::string::size_type foundPos = name.find(delimiter, 0);
|
|
|
|
|
CHECK(foundPos != std::string::npos);
|
|
|
|
|
|
|
|
|
|
auto it = evalResults_.find(
|
|
|
|
|
name.substr(foundPos + delimiter.size(), name.length()));
|
|
|
|
|
if (it == evalResults_.end()) {
|
|
|
|
|
*err = Error("Evaluator does not have the key %s", name.c_str());
|
|
|
|
|
return 0.0f;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string getTypeImpl() const { return "ctc_edit_distance"; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
|
|
|
|
|