From 66adeedf5da99bd5bfa103791fc33afd03abdd18 Mon Sep 17 00:00:00 2001
From: Micooz <Micooz@hotmail.com>
Date: Tue, 11 Aug 2015 20:06:20 +0800
Subject: [PATCH] Finished ANN training for opencv3.0.

---
 src/train/ann_train.cpp | 192 +++++++++++++++++++++++-----------------
 1 file changed, 113 insertions(+), 79 deletions(-)

diff --git a/src/train/ann_train.cpp b/src/train/ann_train.cpp
index 19ad9ff..bdbeac6 100644
--- a/src/train/ann_train.cpp
+++ b/src/train/ann_train.cpp
@@ -1,88 +1,122 @@
-#include "easypr/train/ann_train.h"
+#include "easypr/train/ann_train.h"
+#include "easypr/core/core_func.h"
+#include "easypr/core/chars_identify.h"
+#include "easypr/util/util.h"
 #include "easypr/config.h"
-
-namespace easypr{
-
-  AnnTrain::AnnTrain(const char* chars_folder, const char* zhchars_folder, const char* xml)
-    :chars_folder_(chars_folder), zhchars_folder_(zhchars_folder), ann_xml_(xml){
-
-  }
-
-  void AnnTrain::train(const int & neurons /* = 40 */){
-    this->getTrainData();
-
-    cv::Mat layers = {
-      train_data_->getNSamples(),// the input layer
-      neurons, // the neurons
-      sizeof(kChinese) + sizeof(kCharacters) // the output layer
-    };
-    ann_->setLayerSizes(layers);
-    ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP);
-    ann_->setBackpropWeightScale(0.1);
-    ann_->setBackpropMomentumScale(0.1);
-
-    std::cout << "Training ANN model, please wait..." << std::endl;
-    long start = utils::getTimestamp();
-    ann_->train(train_data_);
-    long end = utils::getTimestamp();
-    std::cout << "Training done. Elapse: " << (end - start) / 1000 << std::endl;
-
-    ann_->save(ann_xml_);
-    std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl;
-  }
-
-  void AnnTrain::getTrainData(){
-    assert(chars_folder_);
-    assert(zhchars_folder_);
-    // create new
-    cv::Mat samples;
-    cv::Mat responses;
-
-    std::cout << "Collecting chars in " << chars_folder_ << std::endl;
-
-    for (auto i = 0; i < sizeof(kCharacters); ++i){
-      char c = kCharacters[i];
-      char sub_folder[512] = { 0 };
-
-      sprintf(sub_folder, "%s/%c", chars_folder_, c);
-      std::cout << "  >> Featuring characters " << c << " in " << sub_folder << std::endl;
-
-      auto chars_files = utils::getFiles(sub_folder);
-      for (auto file : chars_files){
-        auto img = cv::imread(file);
-        auto fps = features(img, kPredictSize);
-
-        samples.push_back(fps);
-        responses.push_back(i);
-      }
-    }
 
-    std::cout << "Collecting zh-chars in " << zhchars_folder_ << std::endl;
+namespace easypr {
+
+// Remembers the sample folder and the model output path, then creates an empty MLP.
+AnnTrain::AnnTrain(const char* chars_folder, const char* xml)
+    : chars_folder_(chars_folder), ann_xml_(xml) {
+  // Only the model object is created here; layer sizes are set when training starts.
+  ann_ = cv::ml::ANN_MLP::create();
+}
+
+// Configures a 3-layer MLP (hidden size = neurons), trains it and, on success, saves it to ann_xml_.
+void AnnTrain::train(const int& neurons /* = 40 */) {
+  cv::Mat layers = (cv::Mat_<int>(1, 3) << 120, neurons, kCharsTotalNumber);  // input, hidden, output
+  ann_->setLayerSizes(layers);
+  ann_->setActivationFunction(cv::ml::ANN_MLP::SIGMOID_SYM, 1, 1);
+  ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP);
+  ann_->setBackpropWeightScale(0.1);
+  ann_->setBackpropMomentumScale(0.1);
+
+  auto traindata = train_data();
+  std::cout << "Training ANN model, please wait..." << std::endl;
+  long start = utils::getTimestamp();
+  const bool trained = ann_->train(traindata);
+  long end = utils::getTimestamp();
+  if (!trained) {  // train() reports failure through its return value; do not save such a model
+    std::cerr << "Training failed, no model was saved." << std::endl;
+    return;
+  }
+  std::cout << "Training done. Time elapse: " << (end - start) << "ms" << std::endl;
+  ann_->save(ann_xml_);
+  std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl;
+}
+
+// Builds the ANN training set from the images under chars_folder_/<kChars[i]>. Each readable image
+// appends its features() output to the samples and its class index i to the labels (made one-hot).
+cv::Ptr<cv::ml::TrainData> AnnTrain::train_data() {
+  assert(chars_folder_);
 
-    for (auto i = 0; i < sizeof(kChinese); ++i){
-      const char *zhc = kChinese[i];
-      char sub_folder[512] = { 0 };
+  cv::Mat samples;
+  std::vector<int> labels;
 
-      sprintf(sub_folder, "%s/%s", zhchars_folder_, zhc);
-      std::cout << "  >> Featuring zh-characters " << zhc << " in " << sub_folder << std::endl;
+  std::cout << "Collecting chars in " << chars_folder_ << std::endl;
 
-      auto chars_files = utils::getFiles(sub_folder);
-      for (auto file : chars_files){
-        auto img = cv::imread(file);
-        auto fps = features(img, kPredictSize);
+  for (int i = 0; i < kCharsTotalNumber; ++i) {
+    auto char_key = kChars[i];
+    // A std::string path cannot overflow, unlike sprintf into a fixed 512-byte buffer.
+    const std::string sub_folder = std::string(chars_folder_) + "/" + char_key;
 
-        samples.push_back(fps);
-        responses.push_back(i + sizeof(kCharacters));
+    std::cout << "  >> Featuring characters " << char_key << " in " << sub_folder << std::endl;
+    auto chars_files = utils::getFiles(sub_folder.c_str());
+    for (const auto& file : chars_files) {
+      auto img = cv::imread(file);
+      if (img.empty()) {  // cv::imread returns an empty Mat when the file cannot be read
+        std::cerr << "  >> Skipping unreadable image " << file << std::endl;
+        continue;
+      }
+      auto fps = features(img, kPredictSize);
+      samples.push_back(fps);
+      labels.push_back(i);
+    }
+  }
+
+  cv::Mat samples_;
+  samples.convertTo(samples_, CV_32F);
+  // One target row per sample: all zeros except a 1 in the column of the sample's label.
+  cv::Mat train_classes = cv::Mat::zeros((int) labels.size(), kCharsTotalNumber, CV_32F);
+  for (int i = 0; i < train_classes.rows; ++i) {
+    train_classes.at<float>(i, labels[i]) = 1.f;
+  }
+  return cv::ml::TrainData::create(samples_, cv::ml::SampleTypes::ROW_SAMPLE, train_classes);
+}
+
+// Runs the character classifier over every image under chars_folder_/<kChars[i]> and prints,
+// per class, the sum / correct / rate line followed by a sample of misclassified file names.
+void AnnTrain::test() {
+  assert(chars_folder_);
+
+  for (int i = 0; i < kCharsTotalNumber; ++i) {
+    auto char_key = kChars[i];
+    // A std::string path cannot overflow, unlike sprintf into a fixed 512-byte buffer.
+    const std::string sub_folder = std::string(chars_folder_) + "/" + char_key;
+    fprintf(stdout, "  >> Testing characters %s in %s \n",
+            char_key, sub_folder.c_str());
+
+    auto chars_files = utils::getFiles(sub_folder.c_str());
+    int corrects = 0, sum = 0;
+    std::vector<std::string> error_files;
+
+    for (const auto& file : chars_files) {
+      auto img = cv::imread(file);
+      std::pair<std::string, std::string> ch = CharsIdentify::instance()->identify(img);
+      if (ch.first == char_key) {
+        ++corrects;
+      } else {
+        error_files.push_back(utils::getFileName(file));
       }
+      ++sum;
     }
-    //
-    cv::Mat samples_row;
-    cv::Mat responses_row;
-
-    samples.convertTo(samples_row, CV_32FC1);
-    responses.convertTo(responses_row, CV_32FC1);
+    fprintf(stdout, "  >>   [sum: %d, correct: %d, rate: %.2f]\n",
+            sum, corrects, (float)corrects / (sum == 0 ? 1 : sum));
+
+    // With 10 or more failures only the first tenth (rounded up) is listed; integer
+    // arithmetic gives that count without offsetting an iterator by a floating-point value.
+    const std::size_t total = error_files.size();
+    const std::size_t shown = total >= 10 ? total - total * 9 / 10 : total;
+    std::string error_string;
+    for (std::size_t k = 0; k < shown; ++k) {
+      if (k != 0) error_string.append(", ");
+      error_string.append(error_files[k]);
+    }
+    // Mark the list as abbreviated only when names were actually left out.
+    if (shown < total) error_string.append(" ...");
+    fprintf(stdout, "  >>   [%s]\n", error_string.c_str());
+  }
+}
 
-    train_data_ = cv::ml::TrainData::create(samples_row, ml::SampleTypes::ROW_SAMPLE, responses_row);
-  }
-
-}
+}