修复KNN模型注入及获取的BUG

pull/41/head
thenk008 5 years ago
parent ca90f23abf
commit f9254bd0db

@ -17,6 +17,7 @@ import org.wlld.imageRecognition.modelEntity.LvqModel;
import org.wlld.imageRecognition.modelEntity.MatrixModel; import org.wlld.imageRecognition.modelEntity.MatrixModel;
import org.wlld.nerveCenter.NerveManager; import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveCenter.Normalization; import org.wlld.nerveCenter.Normalization;
import org.wlld.nerveEntity.BodyList;
import org.wlld.nerveEntity.ModelParameter; import org.wlld.nerveEntity.ModelParameter;
import org.wlld.nerveEntity.SensoryNerve; import org.wlld.nerveEntity.SensoryNerve;
import org.wlld.param.Cutting; import org.wlld.param.Cutting;
@ -539,15 +540,18 @@ public class TempleConfig {
case Classifier.KNN: case Classifier.KNN:
if (knn != null) { if (knn != null) {
Map<Integer, List<Matrix>> listMap = knn.getFeatureMap(); Map<Integer, List<Matrix>> listMap = knn.getFeatureMap();
Map<Integer, List<List<Double>>> knnVector = new HashMap<>(); List<BodyList> knnVector = new ArrayList<>();
for (Map.Entry<Integer, List<Matrix>> entry : listMap.entrySet()) { for (Map.Entry<Integer, List<Matrix>> entry : listMap.entrySet()) {
List<Matrix> list = entry.getValue(); List<Matrix> list = entry.getValue();
List<List<Double>> listFeature = new ArrayList<>(); List<List<Double>> listFeature = new ArrayList<>();
BodyList bodyList = new BodyList();
bodyList.setLists(listFeature);
bodyList.setType(entry.getKey());
for (Matrix matrix : list) { for (Matrix matrix : list) {
List<Double> list1 = MatrixOperation.rowVectorToList(matrix); List<Double> list1 = MatrixOperation.rowVectorToList(matrix);
listFeature.add(list1); listFeature.add(list1);
} }
knnVector.put(entry.getKey(), listFeature); knnVector.add(bodyList);
} }
modelParameter.setKnnVector(knnVector); modelParameter.setKnnVector(knnVector);
} }
@ -662,11 +666,11 @@ public class TempleConfig {
nerveManager.insertModelParameter(modelParameter); nerveManager.insertModelParameter(modelParameter);
break; break;
case Classifier.KNN: case Classifier.KNN:
Map<Integer, List<List<Double>>> knnVector = modelParameter.getKnnVector(); List<BodyList> knnVector = modelParameter.getKnnVector();
if (knn != null && knnVector != null) { if (knn != null && knnVector != null) {
for (Map.Entry<Integer, List<List<Double>>> entry : knnVector.entrySet()) { for (BodyList bodyList : knnVector) {
List<List<Double>> featureList = entry.getValue(); int type = bodyList.getType();
int type = entry.getKey(); List<List<Double>> featureList = bodyList.getLists();
for (List<Double> list : featureList) { for (List<Double> list : featureList) {
Matrix matrix = MatrixOperation.listToRowVector(list); Matrix matrix = MatrixOperation.listToRowVector(list);
knn.insertMatrix(matrix, type); knn.insertMatrix(matrix, type);

@ -0,0 +1,24 @@
package org.wlld.nerveEntity;
import java.util.List;
public class BodyList {
private int type;
private List<List<Double>> lists;
public int getType() {
return type;
}
public void setType(int type) {
this.type = type;
}
public List<List<Double>> getLists() {
return lists;
}
public void setLists(List<List<Double>> lists) {
this.lists = lists;
}
}

@ -23,15 +23,15 @@ public class ModelParameter {
private Map<Integer, KBorder> borderMap = new HashMap<>();//边框距离模型 private Map<Integer, KBorder> borderMap = new HashMap<>();//边框距离模型
private LvqModel lvqModel;//LVQ模型 private LvqModel lvqModel;//LVQ模型
private Map<Integer, List<Double>> matrixK = new HashMap<>();//均值特征向量 private Map<Integer, List<Double>> matrixK = new HashMap<>();//均值特征向量
private Map<Integer, List<List<Double>>> knnVector;//Knn模型 private List<BodyList> knnVector;//Knn模型
private Frame frame;//先验边框 private Frame frame;//先验边框
private double dnnAvg;// private double dnnAvg;//
public Map<Integer, List<List<Double>>> getKnnVector() { public List<BodyList> getKnnVector() {
return knnVector; return knnVector;
} }
public void setKnnVector(Map<Integer, List<List<Double>>> knnVector) { public void setKnnVector(List<BodyList> knnVector) {
this.knnVector = knnVector; this.knnVector = knnVector;
} }

@ -1,6 +1,7 @@
package coverTest; package coverTest;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.Matrix;
import org.wlld.config.Classifier; import org.wlld.config.Classifier;
import org.wlld.config.RZ; import org.wlld.config.RZ;
@ -19,7 +20,7 @@ import java.util.List;
public class FoodTest { public class FoodTest {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
//test(); test();
} }
public static void one(double[] test, double[] right, double[] wrong) { public static void one(double[] test, double[] right, double[] wrong) {
@ -39,7 +40,7 @@ public class FoodTest {
public static void test2(TempleConfig templeConfig) throws Exception { public static void test2(TempleConfig templeConfig) throws Exception {
if (templeConfig == null) { if (templeConfig == null) {
templeConfig = getTemple(); templeConfig = getTemple(null);
} }
Picture picture = new Picture(); Picture picture = new Picture();
List<Specifications> specificationsList = new ArrayList<>(); List<Specifications> specificationsList = new ArrayList<>();
@ -63,11 +64,11 @@ public class FoodTest {
} }
} }
public static TempleConfig getTemple() throws Exception { public static TempleConfig getTemple(ModelParameter modelParameter) throws Exception {
TempleConfig templeConfig = new TempleConfig(); TempleConfig templeConfig = new TempleConfig();
templeConfig.isShowLog(true);//是否打印日志 //templeConfig.isShowLog(true);//是否打印日志
Cutting cutting = templeConfig.getCutting(); Cutting cutting = templeConfig.getCutting();
Food food =templeConfig.getFood(); Food food = templeConfig.getFood();
//切割 //切割
cutting.setMaxRain(320);//切割阈值 cutting.setMaxRain(320);//切割阈值
cutting.setTh(0.88); cutting.setTh(0.88);
@ -83,12 +84,15 @@ public class FoodTest {
food.setTimes(2);//聚类数据增强 food.setTimes(2);//聚类数据增强
templeConfig.setClassifier(Classifier.KNN); templeConfig.setClassifier(Classifier.KNN);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3); templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
if (modelParameter != null) {
templeConfig.insertModel(modelParameter);
}
return templeConfig; return templeConfig;
} }
public static void test() throws Exception { public static void test() throws Exception {
Picture picture = new Picture(); Picture picture = new Picture();
TempleConfig templeConfig = getTemple(); TempleConfig templeConfig = getTemple(null);
Operation operation = new Operation(templeConfig); Operation operation = new Operation(templeConfig);
List<Specifications> specificationsList = new ArrayList<>(); List<Specifications> specificationsList = new ArrayList<>();
Specifications specifications = new Specifications(); Specifications specifications = new Specifications();
@ -120,9 +124,12 @@ public class FoodTest {
operation.colorStudy(threeChannelMatrix10, 10, specificationsList); operation.colorStudy(threeChannelMatrix10, 10, specificationsList);
System.out.println("=======================================" + i); System.out.println("=======================================" + i);
} }
ModelParameter modelParameter = templeConfig.getModel();
templeConfig.finishStudy(); String model = JSON.toJSONString(modelParameter);
test2(templeConfig); System.out.println(model);
ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
//templeConfig.finishStudy();
test2(getTemple(modelParameter1));
} }
public static void study() throws Exception { public static void study() throws Exception {

@ -1,8 +1,11 @@
package coverTest; package coverTest;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.Ma; import org.wlld.Ma;
import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation; import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.ModelData;
import org.wlld.config.Classifier; import org.wlld.config.Classifier;
import org.wlld.config.RZ; import org.wlld.config.RZ;
import org.wlld.config.StudyPattern; import org.wlld.config.StudyPattern;
@ -13,6 +16,7 @@ import org.wlld.imageRecognition.Operation;
import org.wlld.imageRecognition.Picture; import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig; import org.wlld.imageRecognition.TempleConfig;
import org.wlld.nerveCenter.NerveManager; import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.nerveEntity.SensoryNerve; import org.wlld.nerveEntity.SensoryNerve;
import org.wlld.tools.ArithUtil; import org.wlld.tools.ArithUtil;
@ -32,6 +36,13 @@ public class PicTest {
//testImage(right, wrong, a, b); //testImage(right, wrong, a, b);
//test(); //test();
tm();
}
public static void tm() {
String model = ModelData.DATA4;
int index = model.indexOf("knnVector");
} }
public static void test() throws Exception {//对图像进行识别测试 public static void test() throws Exception {//对图像进行识别测试

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save