From 66adeedf5da99bd5bfa103791fc33afd03abdd18 Mon Sep 17 00:00:00 2001 From: Micooz <Micooz@hotmail.com> Date: Tue, 11 Aug 2015 20:06:20 +0800 Subject: [PATCH] Finished ANN training for opencv3.0. --- src/train/ann_train.cpp | 192 +++++++++++++++++++++++----------------- 1 file changed, 113 insertions(+), 79 deletions(-) diff --git a/src/train/ann_train.cpp b/src/train/ann_train.cpp index 19ad9ff..bdbeac6 100644 --- a/src/train/ann_train.cpp +++ b/src/train/ann_train.cpp @@ -1,88 +1,122 @@ -#include "easypr/train/ann_train.h" +#include "easypr/train/ann_train.h" +#include "easypr/core/core_func.h" +#include "easypr/core/chars_identify.h" +#include "easypr/util/util.h" #include "easypr/config.h" - -namespace easypr{ - - AnnTrain::AnnTrain(const char* chars_folder, const char* zhchars_folder, const char* xml) - :chars_folder_(chars_folder), zhchars_folder_(zhchars_folder), ann_xml_(xml){ - - } - - void AnnTrain::train(const int & neurons /* = 40 */){ - this->getTrainData(); - - cv::Mat layers = { - train_data_->getNSamples(),// the input layer - neurons, // the neurons - sizeof(kChinese) + sizeof(kCharacters) // the output layer - }; - ann_->setLayerSizes(layers); - ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP); - ann_->setBackpropWeightScale(0.1); - ann_->setBackpropMomentumScale(0.1); - - std::cout << "Training ANN model, please wait..." << std::endl; - long start = utils::getTimestamp(); - ann_->train(train_data_); - long end = utils::getTimestamp(); - std::cout << "Training done. Elapse: " << (end - start) / 1000 << std::endl; - - ann_->save(ann_xml_); - std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl; - } - - void AnnTrain::getTrainData(){ - assert(chars_folder_); - assert(zhchars_folder_); - // create new - cv::Mat samples; - cv::Mat responses; - - std::cout << "Collecting chars in " << chars_folder_ << std::endl; - - for (auto i = 0; i < sizeof(kCharacters); ++i){ - char c = kCharacters[i]; - char sub_folder[512] = { 0 }; - - sprintf(sub_folder, "%s/%c", chars_folder_, c); - std::cout << " >> Featuring characters " << c << " in " << sub_folder << std::endl; - - auto chars_files = utils::getFiles(sub_folder); - for (auto file : chars_files){ - auto img = cv::imread(file); - auto fps = features(img, kPredictSize); - - samples.push_back(fps); - responses.push_back(i); - } - } - std::cout << "Collecting zh-chars in " << zhchars_folder_ << std::endl; +namespace easypr { + +AnnTrain::AnnTrain(const char* chars_folder, + const char* xml) + : chars_folder_(chars_folder), + ann_xml_(xml) { + ann_ = cv::ml::ANN_MLP::create(); +} + +void AnnTrain::train(const int& neurons /* = 40 */) { + cv::Mat layers(1, 3, CV_32SC1); + layers.at<int>(0) = 120; // the input layer + layers.at<int>(1) = neurons; // the neurons + layers.at<int>(2) = kCharsTotalNumber; // the output layer + + ann_->setLayerSizes(layers); + ann_->setActivationFunction(cv::ml::ANN_MLP::SIGMOID_SYM, 1, 1); + ann_->setTrainMethod(cv::ml::ANN_MLP::TrainingMethods::BACKPROP); + ann_->setBackpropWeightScale(0.1); + ann_->setBackpropMomentumScale(0.1); + + auto traindata = train_data(); + std::cout << "Training ANN model, please wait..." << std::endl; + long start = utils::getTimestamp(); + ann_->train(traindata); + long end = utils::getTimestamp(); + std::cout << "Training done. Time elapse: " << (end - start) << "ms" << std::endl; + + ann_->save(ann_xml_); + std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl; +} + +cv::Ptr<cv::ml::TrainData> AnnTrain::train_data() { + assert(chars_folder_); - for (auto i = 0; i < sizeof(kChinese); ++i){ - const char *zhc = kChinese[i]; - char sub_folder[512] = { 0 }; + cv::Mat samples; + std::vector<int> labels; - sprintf(sub_folder, "%s/%s", zhchars_folder_, zhc); - std::cout << " >> Featuring zh-characters " << zhc << " in " << sub_folder << std::endl; + std::cout << "Collecting chars in " << chars_folder_ << std::endl; - auto chars_files = utils::getFiles(sub_folder); - for (auto file : chars_files){ - auto img = cv::imread(file); - auto fps = features(img, kPredictSize); + for (int i = 0; i < kCharsTotalNumber; ++i) { + auto char_key = kChars[i]; + char sub_folder[512] = {0}; - samples.push_back(fps); - responses.push_back(i + sizeof(kCharacters)); + sprintf(sub_folder, "%s/%s", chars_folder_, char_key); + std::cout << " >> Featuring characters " << char_key << " in " << sub_folder << std::endl; + + auto chars_files = utils::getFiles(sub_folder); + for (auto file : chars_files) { + auto img = cv::imread(file); + auto fps = features(img, kPredictSize); + + samples.push_back(fps); + labels.push_back(i); + } + } + + cv::Mat samples_; + samples.convertTo(samples_, CV_32F); + cv::Mat train_classes = cv::Mat::zeros((int) labels.size(), kCharsTotalNumber, + CV_32F); + + for (int i = 0; i < train_classes.rows; ++i) { + train_classes.at<float>(i, labels[i]) = 1.f; + } + + return cv::ml::TrainData::create(samples_, + cv::ml::SampleTypes::ROW_SAMPLE, + train_classes); +} + +void AnnTrain::test() { + assert(chars_folder_); + + for (int i = 0; i < kCharsTotalNumber; ++i) { + auto char_key = kChars[i]; + char sub_folder[512] = {0}; + + sprintf(sub_folder, "%s/%s", chars_folder_, char_key); + fprintf(stdout, " >> Testing characters %s in %s \n", + char_key, sub_folder); + + auto chars_files = utils::getFiles(sub_folder); + int corrects = 0, sum = 0; + std::vector<std::string> error_files; + + for (auto file : chars_files) { + auto img = cv::imread(file); + std::pair<std::string, std::string> ch = CharsIdentify::instance()->identify(img); + if (ch.first == char_key) { + ++corrects; + } else { + error_files.push_back(utils::getFileName(file)); } + ++sum; } - // - cv::Mat samples_row; - cv::Mat responses_row; - - samples.convertTo(samples_row, CV_32FC1); - responses.convertTo(responses_row, CV_32FC1); + fprintf(stdout, " >> [sum: %d, correct: %d, rate: %.2f]\n", + sum, corrects, (float)corrects / (sum == 0 ? 1 : sum)); + std::string error_string; + auto end = error_files.end(); + if (error_files.size() >= 10) { + end -= error_files.size() * (1 - 0.1); + } + for(auto i = error_files.begin(); i != end; ++i) { + error_string.append(*i); + if (i != end - 1) { + error_string.append(", "); + } else { + error_string.append(" ..."); + } + } + fprintf(stdout, " >> [%s]\n", error_string.c_str()); + } +} - train_data_ = cv::ml::TrainData::create(samples_row, ml::SampleTypes::ROW_SAMPLE, responses_row); - } - -} +}