Added class ITrain interface.

v1.6alpha
Micooz 10 years ago
parent 7cf0392991
commit cac3427261

@ -34,6 +34,7 @@ set(SOURCE_FILES
src/core/plate_recognize.cpp
src/train/ann_train.cpp
src/train/svm_train.cpp
src/train/train.cpp
src/preprocess/deface.cpp
src/preprocess/gdts.cpp
src/preprocess/mc_data.cpp

@ -10,5 +10,6 @@
#include "easypr/util/util.h"
#include "easypr/util/program_options.h"
#include "easypr/api.hpp"
#include "easypr/config.h"
#endif //EASYPR_EASYPR_H

@ -1,12 +1,19 @@
#ifndef EASYPR_TRAIN_TRAIN_H_
#define EASYPR_TRAIN_TRAIN_H_
#ifndef EASYPR_CONFIG_H_
#define EASYPR_CONFIG_H_
namespace easypr {
static const char* kDefaultSvmPath = "resources/model/svm.xml";
static const char* kDefaultAnnPath = "resources/model/ann.xml";
static const int kPredictSize = 10;
typedef enum {
kForward = 1, // correspond to "has plate"
kInverse = 0 // correspond to "no plate"
} SvmLabel;
static const float kSvmPercentage = 0.7f;
static const int kPredictSize = 10;
static const int kNeurons = 40;
static const char *kChars[] = {
"0", "1", "2",
@ -59,4 +66,4 @@ static bool kDebug = false;
}
#endif // EASYPR_TRAIN_TRAIN_H_
#endif // EASYPR_CONFIG_H_

@ -1,25 +1,25 @@
#ifndef EASYPR_TRAIN_ANNTRAIN_H_
#define EASYPR_TRAIN_ANNTRAIN_H_
#include <opencv2/opencv.hpp>
#include "easypr/train/train.h"
namespace easypr{
namespace easypr {
class AnnTrain{
public:
class AnnTrain: public ITrain {
public:
explicit AnnTrain(const char* chars_folder, const char* xml);
void train(const int & neurons = 40);
virtual void train();
void test();
virtual void test();
private:
cv::Ptr<cv::ml::TrainData> train_data();
cv::Ptr<cv::ml::ANN_MLP> ann_;
private:
virtual cv::Ptr<cv::ml::TrainData> tdata();
cv::Ptr<cv::ml::ANN_MLP> ann_;
const char* ann_xml_;
const char* chars_folder_;
};
};
}

@ -0,0 +1,23 @@
#ifndef EASYPR_TRAIN_TRAIN_H_
#define EASYPR_TRAIN_TRAIN_H_
#include <opencv2/opencv.hpp>
namespace easypr {
class ITrain {
public:
ITrain();
virtual ~ITrain();
virtual void train() = 0;
virtual void test() = 0;
private:
virtual cv::Ptr<cv::ml::TrainData> tdata() = 0;
};
}
#endif // EASYPR_TRAIN_TRAIN_H_

@ -13,10 +13,10 @@ AnnTrain::AnnTrain(const char* chars_folder,
ann_ = cv::ml::ANN_MLP::create();
}
void AnnTrain::train(const int& neurons /* = 40 */) {
void AnnTrain::train() {
cv::Mat layers(1, 3, CV_32SC1);
layers.at<int>(0) = 120; // the input layer
layers.at<int>(1) = neurons; // the neurons
layers.at<int>(1) = kNeurons; // the neurons
layers.at<int>(2) = kCharsTotalNumber; // the output layer
ann_->setLayerSizes(layers);
@ -25,7 +25,7 @@ void AnnTrain::train(const int& neurons /* = 40 */) {
ann_->setBackpropWeightScale(0.1);
ann_->setBackpropMomentumScale(0.1);
auto traindata = train_data();
auto traindata = tdata();
std::cout << "Training ANN model, please wait..." << std::endl;
long start = utils::getTimestamp();
ann_->train(traindata);
@ -36,45 +36,6 @@ void AnnTrain::train(const int& neurons /* = 40 */) {
std::cout << "Your ANN Model was saved to " << ann_xml_ << std::endl;
}
cv::Ptr<cv::ml::TrainData> AnnTrain::train_data() {
assert(chars_folder_);
cv::Mat samples;
std::vector<int> labels;
std::cout << "Collecting chars in " << chars_folder_ << std::endl;
for (int i = 0; i < kCharsTotalNumber; ++i) {
auto char_key = kChars[i];
char sub_folder[512] = {0};
sprintf(sub_folder, "%s/%s", chars_folder_, char_key);
std::cout << " >> Featuring characters " << char_key << " in " << sub_folder << std::endl;
auto chars_files = utils::getFiles(sub_folder);
for (auto file : chars_files) {
auto img = cv::imread(file, 0); // a grayscale image
auto fps = features(img, kPredictSize);
samples.push_back(fps);
labels.push_back(i);
}
}
cv::Mat samples_;
samples.convertTo(samples_, CV_32F);
cv::Mat train_classes = cv::Mat::zeros((int) labels.size(), kCharsTotalNumber,
CV_32F);
for (int i = 0; i < train_classes.rows; ++i) {
train_classes.at<float>(i, labels[i]) = 1.f;
}
return cv::ml::TrainData::create(samples_,
cv::ml::SampleTypes::ROW_SAMPLE,
train_classes);
}
void AnnTrain::test() {
assert(chars_folder_);
@ -120,4 +81,43 @@ void AnnTrain::test() {
}
}
cv::Ptr<cv::ml::TrainData> AnnTrain::tdata() {
assert(chars_folder_);
cv::Mat samples;
std::vector<int> labels;
std::cout << "Collecting chars in " << chars_folder_ << std::endl;
for (int i = 0; i < kCharsTotalNumber; ++i) {
auto char_key = kChars[i];
char sub_folder[512] = {0};
sprintf(sub_folder, "%s/%s", chars_folder_, char_key);
std::cout << " >> Featuring characters " << char_key << " in " << sub_folder << std::endl;
auto chars_files = utils::getFiles(sub_folder);
for (auto file : chars_files) {
auto img = cv::imread(file, 0); // a grayscale image
auto fps = features(img, kPredictSize);
samples.push_back(fps);
labels.push_back(i);
}
}
cv::Mat samples_;
samples.convertTo(samples_, CV_32F);
cv::Mat train_classes = cv::Mat::zeros((int) labels.size(), kCharsTotalNumber,
CV_32F);
for (int i = 0; i < train_classes.rows; ++i) {
train_classes.at<float>(i, labels[i]) = 1.f;
}
return cv::ml::TrainData::create(samples_,
cv::ml::SampleTypes::ROW_SAMPLE,
train_classes);
}
}

@ -0,0 +1,9 @@
#include "easypr/train/train.h"
namespace easypr{
ITrain::ITrain() {}
ITrain::~ITrain() {}
}
Loading…
Cancel
Save