修复KNN模型注入及获取的BUG

pull/41/head
thenk008 5 years ago
parent ca90f23abf
commit f9254bd0db

@ -17,6 +17,7 @@ import org.wlld.imageRecognition.modelEntity.LvqModel;
import org.wlld.imageRecognition.modelEntity.MatrixModel;
import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveCenter.Normalization;
import org.wlld.nerveEntity.BodyList;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.nerveEntity.SensoryNerve;
import org.wlld.param.Cutting;
@ -539,15 +540,18 @@ public class TempleConfig {
case Classifier.KNN:
if (knn != null) {
Map<Integer, List<Matrix>> listMap = knn.getFeatureMap();
Map<Integer, List<List<Double>>> knnVector = new HashMap<>();
List<BodyList> knnVector = new ArrayList<>();
for (Map.Entry<Integer, List<Matrix>> entry : listMap.entrySet()) {
List<Matrix> list = entry.getValue();
List<List<Double>> listFeature = new ArrayList<>();
BodyList bodyList = new BodyList();
bodyList.setLists(listFeature);
bodyList.setType(entry.getKey());
for (Matrix matrix : list) {
List<Double> list1 = MatrixOperation.rowVectorToList(matrix);
listFeature.add(list1);
}
knnVector.put(entry.getKey(), listFeature);
knnVector.add(bodyList);
}
modelParameter.setKnnVector(knnVector);
}
@ -662,11 +666,11 @@ public class TempleConfig {
nerveManager.insertModelParameter(modelParameter);
break;
case Classifier.KNN:
Map<Integer, List<List<Double>>> knnVector = modelParameter.getKnnVector();
List<BodyList> knnVector = modelParameter.getKnnVector();
if (knn != null && knnVector != null) {
for (Map.Entry<Integer, List<List<Double>>> entry : knnVector.entrySet()) {
List<List<Double>> featureList = entry.getValue();
int type = entry.getKey();
for (BodyList bodyList : knnVector) {
int type = bodyList.getType();
List<List<Double>> featureList = bodyList.getLists();
for (List<Double> list : featureList) {
Matrix matrix = MatrixOperation.listToRowVector(list);
knn.insertMatrix(matrix, type);

@ -0,0 +1,24 @@
package org.wlld.nerveEntity;
import java.util.List;
public class BodyList {
private int type;
private List<List<Double>> lists;
public int getType() {
return type;
}
public void setType(int type) {
this.type = type;
}
public List<List<Double>> getLists() {
return lists;
}
public void setLists(List<List<Double>> lists) {
this.lists = lists;
}
}

@ -23,15 +23,15 @@ public class ModelParameter {
private Map<Integer, KBorder> borderMap = new HashMap<>();//边框距离模型
private LvqModel lvqModel;//LVQ模型
private Map<Integer, List<Double>> matrixK = new HashMap<>();//均值特征向量
private Map<Integer, List<List<Double>>> knnVector;//Knn模型
private List<BodyList> knnVector;//Knn模型
private Frame frame;//先验边框
private double dnnAvg;//
public Map<Integer, List<List<Double>>> getKnnVector() {
public List<BodyList> getKnnVector() {
return knnVector;
}
public void setKnnVector(Map<Integer, List<List<Double>>> knnVector) {
public void setKnnVector(List<BodyList> knnVector) {
this.knnVector = knnVector;
}

@ -1,6 +1,7 @@
package coverTest;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.config.Classifier;
import org.wlld.config.RZ;
@ -19,7 +20,7 @@ import java.util.List;
public class FoodTest {
public static void main(String[] args) throws Exception {
//test();
test();
}
public static void one(double[] test, double[] right, double[] wrong) {
@ -39,7 +40,7 @@ public class FoodTest {
public static void test2(TempleConfig templeConfig) throws Exception {
if (templeConfig == null) {
templeConfig = getTemple();
templeConfig = getTemple(null);
}
Picture picture = new Picture();
List<Specifications> specificationsList = new ArrayList<>();
@ -63,11 +64,11 @@ public class FoodTest {
}
}
public static TempleConfig getTemple() throws Exception {
public static TempleConfig getTemple(ModelParameter modelParameter) throws Exception {
TempleConfig templeConfig = new TempleConfig();
templeConfig.isShowLog(true);//是否打印日志
//templeConfig.isShowLog(true);//是否打印日志
Cutting cutting = templeConfig.getCutting();
Food food =templeConfig.getFood();
Food food = templeConfig.getFood();
//切割
cutting.setMaxRain(320);//切割阈值
cutting.setTh(0.88);
@ -83,12 +84,15 @@ public class FoodTest {
food.setTimes(2);//聚类数据增强
templeConfig.setClassifier(Classifier.KNN);
templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3);
if (modelParameter != null) {
templeConfig.insertModel(modelParameter);
}
return templeConfig;
}
public static void test() throws Exception {
Picture picture = new Picture();
TempleConfig templeConfig = getTemple();
TempleConfig templeConfig = getTemple(null);
Operation operation = new Operation(templeConfig);
List<Specifications> specificationsList = new ArrayList<>();
Specifications specifications = new Specifications();
@ -120,9 +124,12 @@ public class FoodTest {
operation.colorStudy(threeChannelMatrix10, 10, specificationsList);
System.out.println("=======================================" + i);
}
templeConfig.finishStudy();
test2(templeConfig);
ModelParameter modelParameter = templeConfig.getModel();
String model = JSON.toJSONString(modelParameter);
System.out.println(model);
ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
//templeConfig.finishStudy();
test2(getTemple(modelParameter1));
}
public static void study() throws Exception {

@ -1,8 +1,11 @@
package coverTest;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.Ma;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.ModelData;
import org.wlld.config.Classifier;
import org.wlld.config.RZ;
import org.wlld.config.StudyPattern;
@ -13,6 +16,7 @@ import org.wlld.imageRecognition.Operation;
import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.nerveCenter.NerveManager;
import org.wlld.nerveEntity.ModelParameter;
import org.wlld.nerveEntity.SensoryNerve;
import org.wlld.tools.ArithUtil;
@ -32,6 +36,13 @@ public class PicTest {
//testImage(right, wrong, a, b);
//test();
tm();
}
public static void tm() {
String model = ModelData.DATA4;
int index = model.indexOf("knnVector");
}
public static void test() throws Exception {//对图像进行识别测试

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save