diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index 6ac1755..21c7e14 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -212,8 +212,13 @@ public class MatrixOperation { public static double getNorm(Matrix matrix) throws Exception {//求向量范数 if (matrix.getY() == 1 || matrix.getX() == 1) { - Matrix matrix1 = transPosition(matrix); - return Math.sqrt(mulMatrix(matrix1, matrix).getNumber(0, 0)); + double nub = 0; + for (int i = 0; i < matrix.getX(); i++) { + for (int j = 0; j < matrix.getY(); j++) { + nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub); + } + } + return Math.sqrt(nub); } else { throw new Exception("this matrix is not vector"); } diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index c61cb3c..37effec 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -5,10 +5,7 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; import org.wlld.config.StudyPattern; import org.wlld.i.OutBack; -import org.wlld.imageRecognition.border.Border; -import org.wlld.imageRecognition.border.BorderBody; -import org.wlld.imageRecognition.border.Frame; -import org.wlld.imageRecognition.border.FrameBody; +import org.wlld.imageRecognition.border.*; import org.wlld.nerveEntity.SensoryNerve; import org.wlld.tools.ArithUtil; @@ -80,10 +77,17 @@ public class Operation {//进行计算 intoNerve2(1, matrix, templeConfig.getConvolutionNerveManager().getSensoryNerves(), isKernelStudy, tagging, matrixBack); if (isNerveStudy) { + //卷积后的结果 Matrix myMatrix = matrixBack.getMatrix(); if (templeConfig.isHavePosition() && tagging > 0) { border.end(myMatrix, tagging); } + LVQ lvq = templeConfig.getLvq(); + Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); + MatrixBody matrixBody = new MatrixBody(); + matrixBody.setMatrix(vector); + matrixBody.setId(tagging); + lvq.insertMatrixBody(matrixBody); //进行聚类 Map kMatrixMap = templeConfig.getkMatrixMap(); if (kMatrixMap.containsKey(tagging)) { @@ -283,12 +287,31 @@ public class Operation {//进行计算 intoNerve2(2, matrix, templeConfig.getConvolutionNerveManager().getSensoryNerves(), false, -1, matrixBack); Matrix myMatrix = matrixBack.getMatrix(); - return getClassificationId(myMatrix); + Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); + + return getClassificationId2(vector); } else { throw new Exception("pattern is wrong"); } } + private int getClassificationId2(Matrix myVector) throws Exception { + int id = 0; + double distEnd = 0; + LVQ lvq = templeConfig.getLvq(); + MatrixBody[] matrixBodies = lvq.getModel(); + for (int i = 0; i < matrixBodies.length; i++) { + MatrixBody matrixBody = matrixBodies[i]; + Matrix vector = matrixBody.getMatrix(); + double dist = lvq.vectorEqual(myVector, vector); + if (distEnd == 0 || dist < distEnd) { + id = matrixBody.getId(); + distEnd = dist; + } + } + return id; + } + private int getClassificationId(Matrix myMatrix) throws Exception { int id = 0; double distEnd = 0; diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index e6213e4..1b73514 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -7,6 +7,7 @@ import org.wlld.function.ReLu; import org.wlld.function.Sigmod; import org.wlld.imageRecognition.border.BorderBody; import org.wlld.imageRecognition.border.Frame; +import org.wlld.imageRecognition.border.LVQ; import org.wlld.nerveCenter.NerveManager; import org.wlld.nerveEntity.ModelParameter; import org.wlld.nerveEntity.SensoryNerve; @@ -30,10 +31,20 @@ public class TempleConfig { private boolean isHavePosition = false;//是否需要锁定物体位置 private Map borderBodyMap = new HashMap<>();//border特征集合 private Map kMatrixMap = new HashMap<>();//K均值矩阵集合 + private LVQ lvq; private Frame frame;//先验边框 private double th = 0.6;//标准阈值 private boolean boxReady = false;//边框已经学习完毕 private double iouTh = 0.5;//IOU阈值 + private int lvqNub = 50;//lvq循环次数,默认50 + + public int getLvqNub() { + return lvqNub; + } + + public void setLvqNub(int lvqNub) { + this.lvqNub = lvqNub; + } public double getIouTh() { return iouTh; @@ -77,6 +88,14 @@ public class TempleConfig { } } + public void startLvq() throws Exception {//进行量化 + lvq.start(); + } + + public LVQ getLvq() { + return lvq; + } + private void border(BorderBody borderBody) throws Exception { Matrix parameter = borderBody.getX();//参数矩阵 Matrix tx = borderBody.getTx(); @@ -166,6 +185,7 @@ public class TempleConfig { private void initConvolutionVision(boolean initPower, int width, int height) throws Exception { int deep = 0; + lvq = new LVQ(classificationNub + 1, lvqNub); Map matrixMap = new HashMap<>();//主键与期望矩阵的映射 while (width > 5 && height > 5) { width = width / 3; diff --git a/src/main/java/org/wlld/imageRecognition/border/KClustering.java b/src/main/java/org/wlld/imageRecognition/border/KClustering.java index 8dcab79..a389890 100644 --- a/src/main/java/org/wlld/imageRecognition/border/KClustering.java +++ b/src/main/java/org/wlld/imageRecognition/border/KClustering.java @@ -15,7 +15,7 @@ public class KClustering { private List matrixList = new ArrayList<>();//聚类集合 private int length;//向量长度 private int speciesQuantity;//种类数量 - private Matrix[] matrices = new Matrix[speciesQuantity];//均值K + private Matrix[] matrices;//均值K private Map> clusterMap = new HashMap<>();//簇 public Matrix[] getMatrices() { @@ -28,6 +28,7 @@ public class KClustering { public KClustering(int speciesQuantity) { this.speciesQuantity = speciesQuantity; + matrices = new Matrix[speciesQuantity]; for (int i = 0; i < speciesQuantity; i++) { clusterMap.put(i, new ArrayList<>()); } diff --git a/src/main/java/org/wlld/imageRecognition/border/LVQ.java b/src/main/java/org/wlld/imageRecognition/border/LVQ.java index 05f4931..6f0aed1 100644 --- a/src/main/java/org/wlld/imageRecognition/border/LVQ.java +++ b/src/main/java/org/wlld/imageRecognition/border/LVQ.java @@ -2,7 +2,6 @@ package org.wlld.imageRecognition.border; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; -import org.wlld.tools.ArithUtil; import java.util.ArrayList; import java.util.List; @@ -15,16 +14,27 @@ import java.util.Random; */ public class LVQ { private int typeNub;//原型聚类个数,即分类个数 - private MatrixBody[] model = new MatrixBody[typeNub];//原型向量 + private MatrixBody[] model;//原型向量 private List matrixList = new ArrayList<>(); private double studyPoint = 0.1;//量化学习率 private int length;//向量长度 + private boolean isReady = false; + private int lvqNub = 50; - public LVQ(int typeNub) { + public boolean isReady() { + return isReady; + } + + public LVQ(int typeNub, int lvqNub) { this.typeNub = typeNub; + this.lvqNub = lvqNub; + model = new MatrixBody[typeNub]; } - public MatrixBody[] getModel() { + public MatrixBody[] getModel() throws Exception { + if (!isReady) { + throw new Exception("not study"); + } return model; } @@ -46,31 +56,32 @@ public class LVQ { } } - private double study() throws Exception { - double error = 0; + private void study() throws Exception { for (MatrixBody matrixBody : matrixList) { Matrix matrix = matrixBody.getMatrix();//特征向量 long type = matrixBody.getId();//类别 + double distEnd = 0; + int id = 0; for (int i = 0; i < typeNub; i++) { MatrixBody modelBody = model[i]; Matrix modelMatrix = modelBody.getMatrix(); - long id = modelBody.getId(); - boolean isRight = id == type;//类别是否相同 - //对矩阵进行修正 - Matrix matrix1 = op(matrix, modelMatrix, isRight); //修正矩阵与原矩阵的范数差 - double dist = vectorEqual(modelMatrix, matrix1); - //将修正后的向量进行赋值 - modelBody.setMatrix(matrix1); - //统计变化值 - error = ArithUtil.add(error, dist); + double dist = vectorEqual(modelMatrix, matrix); + if (distEnd == 0 || dist < distEnd) { + id = matrixBody.getId(); + distEnd = dist; + } } + MatrixBody modelBody = model[id]; + Matrix modelMatrix = modelBody.getMatrix(); + boolean isRight = id == type; + Matrix matrix1 = op(matrix, modelMatrix, isRight); + modelBody.setMatrix(matrix1); } - return error; } //比较两个向量之间的范数差 - private double vectorEqual(Matrix matrix1, Matrix matrix2) throws Exception { + public double vectorEqual(Matrix matrix1, Matrix matrix2) throws Exception { Matrix matrix = MatrixOperation.sub(matrix1, matrix2); return MatrixOperation.getNorm(matrix); } @@ -100,9 +111,9 @@ public class LVQ { model[i] = matrixBody; } //初始化完成 - for (int k = 0; k < 1000; k++) { - double error = study(); - System.out.println("error==" + error); + for (int i = 0; i < lvqNub; i++) { + study(); } + isReady = true; } } diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java index 085a649..7d24525 100644 --- a/src/test/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -30,18 +30,18 @@ public class HelloWorld { public static void test() throws Exception { Picture picture = new Picture(); TempleConfig templeConfig = new TempleConfig(); - templeConfig.setHavePosition(true); - Frame frame = new Frame(); - frame.setWidth(3024); - frame.setHeight(4032); - frame.setLengthHeight(100); - frame.setLengthWidth(100); - templeConfig.setFrame(frame); + //templeConfig.setHavePosition(true); +// Frame frame = new Frame(); +// frame.setWidth(3024); +// frame.setHeight(4032); +// frame.setLengthHeight(100); +// frame.setLengthWidth(100); +// templeConfig.setFrame(frame); ModelParameter modelParameter = JSONObject.parseObject(ModelData.DATA, ModelParameter.class); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 3204, 4032, 1); templeConfig.insertModel(modelParameter); Operation operation = new Operation(templeConfig); - for (int i = 1; i < 20; i++) {//faster rcnn神经网络学习 + for (int i = 1; i < 30; i++) {//faster rcnn神经网络学习 System.out.println("study==" + i); //读取本地URL地址图片,并转化成矩阵 Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); @@ -51,16 +51,17 @@ public class HelloWorld { operation.learning(right, 1, true); operation.learning(wrong, 0, true); } - templeConfig.boxStudy();//边框回归 - templeConfig.clustering();//进行聚类 + templeConfig.startLvq(); + //templeConfig.boxStudy();//边框回归 + //templeConfig.clustering();//进行聚类 // ModelParameter modelParameter1 = templeConfig.getModel(); // String a = JSON.toJSONString(modelParameter1); // System.out.println(a); - for (int j = 1; j < 2; j++) { - Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); - Map> map = operation.lookWithPosition(right, j); - System.out.println("j===" + j); - } +// for (int j = 1; j < 2; j++) { +// Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); +// Map> map = operation.lookWithPosition(right, j); +// System.out.println("j===" + j); +// } //测试集图片,进行识别测试 // for (int j = 121; j < 140; j++) { // Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png");