Simplify SVM training.

v1.6alpha
Micooz 10 years ago
parent d9fc6240c5
commit 7cf0392991

@ -1,48 +1,37 @@
#ifndef EASYPR_SVM_TRAIN_H
#define EASYPR_SVM_TRAIN_H
#ifndef EASYPR_TRAIN_SVMTRAIN_H_
#define EASYPR_TRAIN_SVMTRAIN_H_
#include "easypr/train/train.h"
#include <vector>
#include <opencv2/opencv.hpp>
#include "easypr/config.h"
namespace easypr {
class SvmTrain {
public:
typedef enum {
kForward = 1, // correspond to "has plate"
kInverse = 0 // correspond to "no plate"
} Label;
class SvmTrain : public ITrain {
public:
typedef struct {
std::string file;
SvmLabel label;
} TrainItem;
SvmTrain(const char* forward_data_folder, const char* inverse_data_folder);
SvmTrain(const char* plates_folder, const char* xml);
void train(bool divide = true, float divide_percentage = 0.7,
const char* out_svm_folder = kDefaultSvmPath);
virtual void train();
void runTest(const char* svm_path = kDefaultSvmPath);
virtual void test();
private:
/*
* divide images into train part and test part by percentage
*/
void divide(const char* images_folder, float percentage = 0.7);
private:
void prepare();
void getTrain();
virtual cv::Ptr<cv::ml::TrainData> tdata();
void getTest();
const char* forward_;
const char* inverse_;
// these two variables are used for cv::CvSVM::train_auto()
cv::Mat classes_;
cv::Ptr<cv::ml::TrainData> trainingData_;
// these two variables are used for cv::CvSVM::predict()
std::vector<cv::Mat> test_imgaes_;
std::vector<Label> test_labels_;
};
cv::Ptr<cv::ml::SVM> svm_;
const char* plates_folder_;
const char* svm_xml_;
std::vector<TrainItem> train_file_list_;
std::vector<TrainItem> test_file_list_;
};
}
#endif //EASYPR_SVM_TRAIN_H
#endif // EASYPR_TRAIN_SVMTRAIN_H_

File diff suppressed because it is too large Load Diff

@ -130,43 +130,15 @@ void command_line_handler(int argc, const char* argv[]) {
| SVM Training operations
| ------------------------------------------
|
| $ demo svm --[create|tag|train] options
| $ demo svm --plates=path/to/plates/ [--test] --svm=save/to/svm.xml
|
| ------------------------------------------
*/
options("h,help", "show help information");
options(",svm", "resources/model/svm.xml",
"the svm model file,"
" this option is used for '--tag'(required)"
" and '--train'(save svm model to) functions");
// create
options(
",create",
"create learn data, this function "
"will intercept (--max) raw images (--in) and preprocess into (--out)");
options("i,in", "", "where is the raw images");
options("o,out", "", "where to put the preprocessed images");
options("m,max", "5000", "how many learn data would you want to create");
// tag
options(",tag",
"tag learn data, this function "
"will find plate blocks in your images(--source) "
"as well as classify them into (--has) and (--no)");
options("s,source", "", "where is your images to be classified");
options(",has", "", "put plates in this folder");
options(",no", "", "put images without plate in this folder");
// train
options(",train",
"train given data, "
"including the forward(has plate) and the inverse(no plate).");
options(",has-plate", "", "where is the forward data");
options(",no-plate", "", "where is the inverse data");
options(",divide",
"whether divide train data into two parts by --percentage or not");
options(",percentage", "0.7",
"70% train data will be used for training,"
" others will be used for testing");
options(",test", "don't train again, run testing directly");
options(",plates", "",
"a folder contains both forward data and inverse data in the separated subfolders");
options(",svm", easypr::kDefaultSvmPath, "the svm model file");
options("t,test", "run tests in --plates");
}
options.add_subroutine("ann", "ann operation").make_usage("Usages:");
@ -259,57 +231,17 @@ void command_line_handler(int argc, const char* argv[]) {
std::cout << options("svm");
return;
}
if (parser->has("create")) {
assert(parser->has("in"));
assert(parser->has("out"));
assert(parser->has("max"));
auto in = parser->get("in")->val();
auto out = parser->get("out")->val();
auto max = parser->get("max")->as<int>();
easypr::preprocess::create_learn_data(in.c_str(), out.c_str(),
max);
}
if (parser->has("tag")) {
assert(parser->has("source"));
assert(parser->has("has"));
assert(parser->has("no"));
assert(parser->has("svm"));
auto source = parser->get("source")->val();
auto has_path = parser->get("has")->val();
auto no_path = parser->get("no")->val();
auto svm = parser->get("svm")->val();
easypr::preprocess::tag_data(source.c_str(), has_path.c_str(),
no_path.c_str(), svm.c_str());
std::cout << "Tagging finished, check out output images "
<< "and classify the wrong images manually." << std::endl;
easypr::SvmTrain svm(
parser->get("plates")->c_str(),
parser->get("svm")->c_str()
);
if (parser->has("test")) {
svm.test();
}
if (parser->has("train")) {
assert(parser->has("has-plate"));
assert(parser->has("no-plate"));
assert(parser->has("percentage"));
assert(parser->has("svm"));
auto forward_data_path = parser->get("has-plate")->val();
auto inverse_data_path = parser->get("no-plate")->val();
bool divide = parser->has("divide");
auto percentage = parser->get("percentage")->as<float>();
auto svm_model = parser->get("svm")->val();
bool test = parser->has("test");
easypr::SvmTrain svm(forward_data_path.c_str(),
inverse_data_path.c_str());
if (test) {
svm.runTest(svm_model.c_str());
}
else {
svm.train(divide, percentage, svm_model.c_str());
}
else {
svm.train();
}
})
.found("ann", [&]() {

Loading…
Cancel
Save