Fixed SVM train.

v1.6alpha
Micooz 11 years ago
parent 5990d4a2b0
commit bfe928bec9

@ -1,7 +1,6 @@
#include "easypr/train/svm_train.h"
#include "easypr/config.h"
#include "easypr/core/core_func.h"
#include "easypr/core/plate_judge.h"
#include "easypr/core/feature.h"
#include "easypr/util/util.h"
#ifdef OS_WINDOWS
@ -16,11 +15,10 @@ SvmTrain::SvmTrain(const char* plates_folder, const char* xml)
: plates_folder_(plates_folder), svm_xml_(xml) {
assert(plates_folder);
assert(xml);
srand(unsigned(time(NULL)));
svm_ = cv::ml::SVM::create();
}
void SvmTrain::train() {
svm_ = cv::ml::SVM::create();
svm_->setType(cv::ml::SVM::C_SVC);
svm_->setKernel(cv::ml::SVM::RBF);
svm_->setDegree(0.1);
@ -33,25 +31,32 @@ void SvmTrain::train() {
auto train_data = tdata();
fprintf(stdout, "Training SVM model, please wait...\n");
fprintf(stdout, ">> Training SVM model, please wait...\n");
long start = utils::getTimestamp();
svm_->trainAuto(train_data, 10,
SVM::getDefaultGrid(SVM::C),
SVM::getDefaultGrid(SVM::GAMMA),
SVM::getDefaultGrid(SVM::P),
SVM::getDefaultGrid(SVM::NU),
SVM::getDefaultGrid(SVM::COEF),
SVM::getDefaultGrid(SVM::C),
SVM::getDefaultGrid(SVM::GAMMA),
SVM::getDefaultGrid(SVM::P),
SVM::getDefaultGrid(SVM::NU),
SVM::getDefaultGrid(SVM::COEF),
SVM::getDefaultGrid(SVM::DEGREE),
true
);
long end = utils::getTimestamp();
fprintf(stdout, "Training done. Time elapse: %ldms\n", end - start);
fprintf(stdout, ">> Training done. Time elapse: %ldms\n", end - start);
fprintf(stdout, ">> Saving model file...\n");
svm_->save(svm_xml_);
fprintf(stdout, "Your ANN Model was saved to %s\n", svm_xml_);
fprintf(stdout, ">> Your ANN Model was saved to %s\n", svm_xml_);
fprintf(stdout, ">> Testing...\n");
this->test();
}
void SvmTrain::test() {
this->prepare();
svm_ = cv::ml::SVM::load<cv::ml::SVM>(svm_xml_);
if (test_file_list_.empty()) {
this->prepare();
}
double count_all = test_file_list_.size();
double ptrue_rtrue = 0;
@ -59,16 +64,17 @@ void SvmTrain::test() {
double pfalse_rtrue = 0;
double pfalse_rfalse = 0;
size_t label_index = 0;
for (auto item : test_file_list_) {
auto image = cv::imread(item.file);
auto features = easypr::histeq(image);
features = features.reshape(1, 1);
if (!image.data){
continue;
}
cv::Mat feature;
getHistogramFeatures(image, feature);
feature.reshape(1, 1).convertTo(feature, CV_32F);
int predict;
PlateJudge::instance()->plateJudge(features, predict);
auto predict = static_cast<int>(svm_->predict(feature));
auto real = item.label;
if (predict == kForward && real == kForward) ptrue_rtrue++;
if (predict == kForward && real == kInverse) ptrue_rfalse++;
if (predict == kInverse && real == kForward) pfalse_rtrue++;
@ -87,8 +93,7 @@ void SvmTrain::test() {
std::cout << "precise: " << precise << std::endl;
}
else {
std::cout << "precise: "
<< "NA" << std::endl;
std::cout << "precise: " << "NA" << std::endl;
}
double recall = 0;
@ -97,8 +102,7 @@ void SvmTrain::test() {
std::cout << "recall: " << recall << std::endl;
}
else {
std::cout << "recall: "
<< "NA" << std::endl;
std::cout << "recall: " << "NA" << std::endl;
}
double Fsocre = 0;
@ -107,12 +111,13 @@ void SvmTrain::test() {
std::cout << "Fsocre: " << Fsocre << std::endl;
}
else {
std::cout << "Fsocre: "
<< "NA" << std::endl;
std::cout << "Fsocre: " << "NA" << std::endl;
}
}
void SvmTrain::prepare() {
srand(unsigned(time(NULL)));
char buffer[260] = { 0 };
sprintf(buffer, "%s/has", plates_folder_);
@ -155,7 +160,7 @@ void SvmTrain::prepare() {
// copy the rest of has_file_list to the test_file_list_
test_file_list_.reserve(has_for_test + no_for_test);
for (auto i = has_for_test; i < has_num; i++) {
train_file_list_.push_back({
test_file_list_.push_back({
has_file_list[i],
kForward
});
@ -163,7 +168,7 @@ void SvmTrain::prepare() {
// copy the rest of no_file_list to the end of the test_file_list_
for (auto i = no_for_test; i < no_num; i++) {
train_file_list_.push_back({
test_file_list_.push_back({
no_file_list[i],
kInverse
});
@ -179,21 +184,22 @@ cv::Ptr<cv::ml::TrainData> SvmTrain::tdata() {
for (auto f : train_file_list_) {
auto image = cv::imread(f.file);
if (!image.data) {
fprintf(stdout, ">> Invalid image: %s\n", f.file.c_str());
fprintf(stdout, ">> Invalid image: %s ignore.\n", f.file.c_str());
continue;
}
auto features = easypr::histeq(image);
features = features.reshape(1, 1);
cv::Mat feature;
getHistogramFeatures(image, feature);
feature = feature.reshape(1, 1);
samples.push_back(features);
samples.push_back(feature);
responses.push_back(f.label);
}
cv::Mat samples_;
cv::Mat samples_, responses_;
samples.convertTo(samples_, CV_32F);
cv::Mat(responses).reshape(0, 1).copyTo(responses_);
return cv::ml::TrainData::create(samples_, ml::SampleTypes::ROW_SAMPLE,
cv::Mat(responses));
return cv::ml::TrainData::create(samples_, cv::ml::SampleTypes::ROW_SAMPLE, responses_);
}
} // namespace easypr

Loading…
Cancel
Save