diff --git a/.idea/compiler.xml b/.idea/compiler.xml index d280c68..6aa88ff 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -6,8 +6,8 @@ - + diff --git a/README.md b/README.md index 0638e6f..de2bc6e 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ 学习1200万像素的照片物体,1000张需耗时5-7个小时。 #### 本包为性能优化而对AI算法的修改 * 本包对图像AI算法进行了修改,为应对CPU部署。 -* 卷积神经网络后的全连接层直接替换成了K均值算法进行聚类,通过卷积结果与K均值矩阵欧式距离来进行判定。 +* 卷积神经网络后的全连接层直接替换成了LVQ算法进行特征向量量化学习聚类,通过卷积结果与K均值矩阵欧式距离来进行判定。 * 物体的边框检测通过卷积后的特征向量进行多元线性回归获得,检测边框的候选区并没有使用图像分割(cpu对图像分割算法真是超慢), 而是通过Frame类让用户自定义先验图框大小和先验图框每次移动的检测步长,然后再通过多次检测的IOU来确定是否为同一物体。 * 所以添加定位模式,用户要确定Frame的大小和步长,来替代基于图像分割的候选区推荐算法。 diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index 21c7e14..84ce556 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -57,8 +57,7 @@ public class MatrixOperation { //返回两个向量之间的欧氏距离的平方 public static double getEDist(Matrix matrix1, Matrix matrix2) throws Exception { if (matrix1.isRowVector() && matrix2.isRowVector() && matrix1.getY() == matrix2.getY()) { - mathMul(matrix2, -1); - Matrix matrix = add(matrix1, matrix2); + Matrix matrix = sub(matrix1, matrix2); return getNorm(matrix); } else { throw new Exception("this matrix is not rowVector or length different"); diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index 37effec..6519382 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -82,20 +82,13 @@ public class Operation {//进行计算 if (templeConfig.isHavePosition() && tagging > 0) { border.end(myMatrix, tagging); } + //进行聚类 LVQ lvq = templeConfig.getLvq(); Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); MatrixBody matrixBody = new MatrixBody(); matrixBody.setMatrix(vector); matrixBody.setId(tagging); lvq.insertMatrixBody(matrixBody); - //进行聚类 - Map kMatrixMap = templeConfig.getkMatrixMap(); - if (kMatrixMap.containsKey(tagging)) { - KMatrix kMatrix = kMatrixMap.get(tagging); - kMatrix.addMatrix(myMatrix); - } else { - throw new Exception("not find tag"); - } } } else { throw new Exception("pattern is wrong"); @@ -288,7 +281,6 @@ public class Operation {//进行计算 false, -1, matrixBack); Matrix myMatrix = matrixBack.getMatrix(); Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); - return getClassificationId2(vector); } else { throw new Exception("pattern is wrong"); diff --git a/src/main/java/org/wlld/imageRecognition/border/Box.java b/src/main/java/org/wlld/imageRecognition/border/Box.java new file mode 100644 index 0000000..fa6b630 --- /dev/null +++ b/src/main/java/org/wlld/imageRecognition/border/Box.java @@ -0,0 +1,29 @@ +package org.wlld.imageRecognition.border; + +import org.wlld.MatrixTools.Matrix; + +/** + * @author lidapeng + * @description + * @date 9:11 上午 2020/2/6 + */ +public class Box { + private Matrix matrix;//特征向量 + private Matrix matrixFather;//坐标向量 + + public Matrix getMatrix() { + return matrix; + } + + public void setMatrix(Matrix matrix) { + this.matrix = matrix; + } + + public Matrix getMatrixFather() { + return matrixFather; + } + + public void setMatrixFather(Matrix matrixFather) { + this.matrixFather = matrixFather; + } +} diff --git a/src/main/java/org/wlld/imageRecognition/border/KClustering.java b/src/main/java/org/wlld/imageRecognition/border/KClustering.java index a389890..eefa921 100644 --- a/src/main/java/org/wlld/imageRecognition/border/KClustering.java +++ b/src/main/java/org/wlld/imageRecognition/border/KClustering.java @@ -12,17 +12,17 @@ import java.util.*; * @date 10:14 上午 2020/2/4 */ public class KClustering { - private List matrixList = new ArrayList<>();//聚类集合 + private List matrixList = new ArrayList<>();//聚类集合 private int length;//向量长度 private int speciesQuantity;//种类数量 private Matrix[] matrices;//均值K - private Map> clusterMap = new HashMap<>();//簇 + private Map> clusterMap = new HashMap<>();//簇 public Matrix[] getMatrices() { return matrices; } - public Map> getClusterMap() { + public Map> getClusterMap() { return clusterMap; } @@ -34,7 +34,7 @@ public class KClustering { } } - public void setMatrixList(MatrixBody matrixBody) throws Exception { + public void setMatrixList(Box matrixBody) throws Exception { if (matrixBody.getMatrix().isVector() && matrixBody.getMatrix().isRowVector()) { Matrix matrix = matrixBody.getMatrix(); if (matrixList.size() == 0) { @@ -54,7 +54,7 @@ public class KClustering { private Matrix[] averageMatrix() throws Exception { Matrix[] matrices2 = new Matrix[speciesQuantity];//待比较均值K - for (MatrixBody matrixBody : matrixList) {//遍历当前集合 + for (Box matrixBody : matrixList) {//遍历当前集合 Matrix matrix = matrixBody.getMatrix(); double min = 0; int id = 0; @@ -65,11 +65,11 @@ public class KClustering { id = i; } } - List matrixList1 = clusterMap.get(id); + List matrixList1 = clusterMap.get(id); matrixList1.add(matrixBody); } //重新计算均值 - for (Map.Entry> entry : clusterMap.entrySet()) { + for (Map.Entry> entry : clusterMap.entrySet()) { Matrix matrix = average(entry.getValue()); matrices2[entry.getKey()] = matrix; } @@ -77,30 +77,33 @@ public class KClustering { } private void clear() { - for (Map.Entry> entry : clusterMap.entrySet()) { + for (Map.Entry> entry : clusterMap.entrySet()) { entry.getValue().clear(); } } - private Matrix average(List matrixList) throws Exception {//进行矩阵均值计算 + private Matrix average(List matrixList) throws Exception {//进行矩阵均值计算 double nub = ArithUtil.div(1, matrixList.size()); - Matrix matrix = new Matrix(0, length); - for (MatrixBody matrixBody1 : matrixList) { + Matrix matrix = new Matrix(1, length); + for (Box matrixBody1 : matrixList) { matrix = MatrixOperation.add(matrix, matrixBody1.getMatrix()); } MatrixOperation.mathMul(matrix, nub); return matrix; } + public void start() throws Exception {//开始聚类 if (matrixList.size() > 1) { Random random = new Random(); for (int i = 0; i < matrices.length; i++) {//初始化均值向量 int index = random.nextInt(matrixList.size()); + //要进行深度克隆 matrices[i] = matrixList.get(index).getMatrix(); } //进行两者的比较 boolean isEqual = false; + int nub = 0; do { Matrix[] matrices2 = averageMatrix(); isEqual = equals(matrices, matrices2); @@ -108,8 +111,12 @@ public class KClustering { matrices = matrices2; clear(); } + nub++; } - while (isEqual); + while (!isEqual); + //聚类结束,进行坐标均值矩阵计算 + System.out.println("聚类循环次数:" + nub); + } else { throw new Exception("matrixList number less than 2"); } @@ -126,6 +133,9 @@ public class KClustering { break; } } + if (!isEquals) { + break; + } } return isEquals; } diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java index 7d24525..00f4def 100644 --- a/src/test/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -41,7 +41,7 @@ public class HelloWorld { templeConfig.init(StudyPattern.Accuracy_Pattern, true, 3204, 4032, 1); templeConfig.insertModel(modelParameter); Operation operation = new Operation(templeConfig); - for (int i = 1; i < 30; i++) {//faster rcnn神经网络学习 + for (int i = 1; i < 100; i++) {//faster rcnn神经网络学习 System.out.println("study==" + i); //读取本地URL地址图片,并转化成矩阵 Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); @@ -63,14 +63,14 @@ public class HelloWorld { // System.out.println("j===" + j); // } //测试集图片,进行识别测试 -// for (int j = 121; j < 140; j++) { -// Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); -// Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png"); -// int rightId = operation.toSee(right); -// int wrongId = operation.toSee(wrong); -// System.out.println("该图是菜单:" + rightId); -// System.out.println("该图是桌子:" + wrongId); -// } + for (int j = 121; j < 140; j++) { + Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); + Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png"); + int rightId = operation.toSee(right); + int wrongId = operation.toSee(wrong); + System.out.println("该图是菜单:" + rightId); + System.out.println("该图是桌子:" + wrongId); + } }