|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "Evaluator.h"
|
|
|
|
|
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
|
|
|
|
|
#include "paddle/utils/StringUtil.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
@ -259,7 +260,7 @@ public:
|
|
|
|
|
|
|
|
|
|
virtual void printStats(std::ostream& os) const {
|
|
|
|
|
storeLocalValues();
|
|
|
|
|
os << config_.name() << "=" << evalResults_["error"];
|
|
|
|
|
os << config_.name() << " error = " << evalResults_["error"];
|
|
|
|
|
os << " deletions error = " << evalResults_["deletion_error"];
|
|
|
|
|
os << " insertions error = " << evalResults_["insertion_error"];
|
|
|
|
|
os << " substitution error = " << evalResults_["substitution_error"];
|
|
|
|
@ -293,12 +294,10 @@ public:
|
|
|
|
|
real getValue(const std::string& name, Error* err) const {
|
|
|
|
|
storeLocalValues();
|
|
|
|
|
|
|
|
|
|
const std::string delimiter(".");
|
|
|
|
|
std::string::size_type foundPos = name.find(delimiter, 0);
|
|
|
|
|
CHECK(foundPos != std::string::npos);
|
|
|
|
|
std::vector<std::string> buffers;
|
|
|
|
|
paddle::str::split(name, '.', &buffers);
|
|
|
|
|
auto it = evalResults_.find(buffers[buffers.size() - 1]);
|
|
|
|
|
|
|
|
|
|
auto it = evalResults_.find(
|
|
|
|
|
name.substr(foundPos + delimiter.size(), name.length()));
|
|
|
|
|
if (it == evalResults_.end()) {
|
|
|
|
|
*err = Error("Evaluator does not have the key %s", name.c_str());
|
|
|
|
|
return 0.0f;
|
|
|
|
@ -307,7 +306,11 @@ public:
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string getTypeImpl() const { return "ctc_edit_distance"; }
|
|
|
|
|
std::string getType(const std::string& name, Error* err) const {
|
|
|
|
|
getValue(name, err);
|
|
|
|
|
if (!err->isOK()) return "";
|
|
|
|
|
return "ctc_edit_distance";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
REGISTER_EVALUATOR(ctc_edit_distance, CTCErrorEvaluator);
|
|
|
|
|