diff --git a/.idea/compiler.xml b/.idea/compiler.xml
index d280c68..6aa88ff 100644
--- a/.idea/compiler.xml
+++ b/.idea/compiler.xml
@@ -6,8 +6,8 @@
-
+
diff --git a/README.md b/README.md
index 0638e6f..de2bc6e 100644
--- a/README.md
+++ b/README.md
@@ -202,7 +202,7 @@
学习1200万像素的照片物体,1000张需耗时5-7个小时。
#### 本包为性能优化而对AI算法的修改
* 本包对图像AI算法进行了修改,为应对CPU部署。
-* 卷积神经网络后的全连接层直接替换成了K均值算法进行聚类,通过卷积结果与K均值矩阵欧式距离来进行判定。
+* 卷积神经网络后的全连接层直接替换成了LVQ算法进行特征向量量化学习聚类,通过卷积结果与K均值矩阵欧式距离来进行判定。
* 物体的边框检测通过卷积后的特征向量进行多元线性回归获得,检测边框的候选区并没有使用图像分割(cpu对图像分割算法真是超慢),
而是通过Frame类让用户自定义先验图框大小和先验图框每次移动的检测步长,然后再通过多次检测的IOU来确定是否为同一物体。
* 所以添加定位模式,用户要确定Frame的大小和步长,来替代基于图像分割的候选区推荐算法。
diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java
index 21c7e14..84ce556 100644
--- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java
+++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java
@@ -57,8 +57,7 @@ public class MatrixOperation {
//返回两个向量之间的欧氏距离的平方
public static double getEDist(Matrix matrix1, Matrix matrix2) throws Exception {
if (matrix1.isRowVector() && matrix2.isRowVector() && matrix1.getY() == matrix2.getY()) {
- mathMul(matrix2, -1);
- Matrix matrix = add(matrix1, matrix2);
+ Matrix matrix = sub(matrix1, matrix2);
return getNorm(matrix);
} else {
throw new Exception("this matrix is not rowVector or length different");
diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java
index 37effec..6519382 100644
--- a/src/main/java/org/wlld/imageRecognition/Operation.java
+++ b/src/main/java/org/wlld/imageRecognition/Operation.java
@@ -82,20 +82,13 @@ public class Operation {//进行计算
if (templeConfig.isHavePosition() && tagging > 0) {
border.end(myMatrix, tagging);
}
+ //进行聚类
LVQ lvq = templeConfig.getLvq();
Matrix vector = MatrixOperation.matrixToVector(myMatrix, true);
MatrixBody matrixBody = new MatrixBody();
matrixBody.setMatrix(vector);
matrixBody.setId(tagging);
lvq.insertMatrixBody(matrixBody);
- //进行聚类
- Map kMatrixMap = templeConfig.getkMatrixMap();
- if (kMatrixMap.containsKey(tagging)) {
- KMatrix kMatrix = kMatrixMap.get(tagging);
- kMatrix.addMatrix(myMatrix);
- } else {
- throw new Exception("not find tag");
- }
}
} else {
throw new Exception("pattern is wrong");
@@ -288,7 +281,6 @@ public class Operation {//进行计算
false, -1, matrixBack);
Matrix myMatrix = matrixBack.getMatrix();
Matrix vector = MatrixOperation.matrixToVector(myMatrix, true);
-
return getClassificationId2(vector);
} else {
throw new Exception("pattern is wrong");
diff --git a/src/main/java/org/wlld/imageRecognition/border/Box.java b/src/main/java/org/wlld/imageRecognition/border/Box.java
new file mode 100644
index 0000000..fa6b630
--- /dev/null
+++ b/src/main/java/org/wlld/imageRecognition/border/Box.java
@@ -0,0 +1,29 @@
+package org.wlld.imageRecognition.border;
+
+import org.wlld.MatrixTools.Matrix;
+
+/**
+ * @author lidapeng
+ * @description
+ * @date 9:11 上午 2020/2/6
+ */
+public class Box {
+ private Matrix matrix;//特征向量
+ private Matrix matrixFather;//坐标向量
+
+ public Matrix getMatrix() {
+ return matrix;
+ }
+
+ public void setMatrix(Matrix matrix) {
+ this.matrix = matrix;
+ }
+
+ public Matrix getMatrixFather() {
+ return matrixFather;
+ }
+
+ public void setMatrixFather(Matrix matrixFather) {
+ this.matrixFather = matrixFather;
+ }
+}
diff --git a/src/main/java/org/wlld/imageRecognition/border/KClustering.java b/src/main/java/org/wlld/imageRecognition/border/KClustering.java
index a389890..eefa921 100644
--- a/src/main/java/org/wlld/imageRecognition/border/KClustering.java
+++ b/src/main/java/org/wlld/imageRecognition/border/KClustering.java
@@ -12,17 +12,17 @@ import java.util.*;
* @date 10:14 上午 2020/2/4
*/
public class KClustering {
- private List matrixList = new ArrayList<>();//聚类集合
+ private List matrixList = new ArrayList<>();//聚类集合
private int length;//向量长度
private int speciesQuantity;//种类数量
private Matrix[] matrices;//均值K
- private Map> clusterMap = new HashMap<>();//簇
+ private Map> clusterMap = new HashMap<>();//簇
public Matrix[] getMatrices() {
return matrices;
}
- public Map> getClusterMap() {
+ public Map> getClusterMap() {
return clusterMap;
}
@@ -34,7 +34,7 @@ public class KClustering {
}
}
- public void setMatrixList(MatrixBody matrixBody) throws Exception {
+ public void setMatrixList(Box matrixBody) throws Exception {
if (matrixBody.getMatrix().isVector() && matrixBody.getMatrix().isRowVector()) {
Matrix matrix = matrixBody.getMatrix();
if (matrixList.size() == 0) {
@@ -54,7 +54,7 @@ public class KClustering {
private Matrix[] averageMatrix() throws Exception {
Matrix[] matrices2 = new Matrix[speciesQuantity];//待比较均值K
- for (MatrixBody matrixBody : matrixList) {//遍历当前集合
+ for (Box matrixBody : matrixList) {//遍历当前集合
Matrix matrix = matrixBody.getMatrix();
double min = 0;
int id = 0;
@@ -65,11 +65,11 @@ public class KClustering {
id = i;
}
}
- List matrixList1 = clusterMap.get(id);
+ List matrixList1 = clusterMap.get(id);
matrixList1.add(matrixBody);
}
//重新计算均值
- for (Map.Entry> entry : clusterMap.entrySet()) {
+ for (Map.Entry> entry : clusterMap.entrySet()) {
Matrix matrix = average(entry.getValue());
matrices2[entry.getKey()] = matrix;
}
@@ -77,30 +77,33 @@ public class KClustering {
}
private void clear() {
- for (Map.Entry> entry : clusterMap.entrySet()) {
+ for (Map.Entry> entry : clusterMap.entrySet()) {
entry.getValue().clear();
}
}
- private Matrix average(List matrixList) throws Exception {//进行矩阵均值计算
+ private Matrix average(List matrixList) throws Exception {//进行矩阵均值计算
double nub = ArithUtil.div(1, matrixList.size());
- Matrix matrix = new Matrix(0, length);
- for (MatrixBody matrixBody1 : matrixList) {
+ Matrix matrix = new Matrix(1, length);
+ for (Box matrixBody1 : matrixList) {
matrix = MatrixOperation.add(matrix, matrixBody1.getMatrix());
}
MatrixOperation.mathMul(matrix, nub);
return matrix;
}
+
public void start() throws Exception {//开始聚类
if (matrixList.size() > 1) {
Random random = new Random();
for (int i = 0; i < matrices.length; i++) {//初始化均值向量
int index = random.nextInt(matrixList.size());
+ //要进行深度克隆
matrices[i] = matrixList.get(index).getMatrix();
}
//进行两者的比较
boolean isEqual = false;
+ int nub = 0;
do {
Matrix[] matrices2 = averageMatrix();
isEqual = equals(matrices, matrices2);
@@ -108,8 +111,12 @@ public class KClustering {
matrices = matrices2;
clear();
}
+ nub++;
}
- while (isEqual);
+ while (!isEqual);
+ //聚类结束,进行坐标均值矩阵计算
+ System.out.println("聚类循环次数:" + nub);
+
} else {
throw new Exception("matrixList number less than 2");
}
@@ -126,6 +133,9 @@ public class KClustering {
break;
}
}
+ if (!isEquals) {
+ break;
+ }
}
return isEquals;
}
diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java
index 7d24525..00f4def 100644
--- a/src/test/java/org/wlld/HelloWorld.java
+++ b/src/test/java/org/wlld/HelloWorld.java
@@ -41,7 +41,7 @@ public class HelloWorld {
templeConfig.init(StudyPattern.Accuracy_Pattern, true, 3204, 4032, 1);
templeConfig.insertModel(modelParameter);
Operation operation = new Operation(templeConfig);
- for (int i = 1; i < 30; i++) {//faster rcnn神经网络学习
+ for (int i = 1; i < 100; i++) {//faster rcnn神经网络学习
System.out.println("study==" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
@@ -63,14 +63,14 @@ public class HelloWorld {
// System.out.println("j===" + j);
// }
//测试集图片,进行识别测试
-// for (int j = 121; j < 140; j++) {
-// Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png");
-// Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png");
-// int rightId = operation.toSee(right);
-// int wrongId = operation.toSee(wrong);
-// System.out.println("该图是菜单:" + rightId);
-// System.out.println("该图是桌子:" + wrongId);
-// }
+ for (int j = 121; j < 140; j++) {
+ Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png");
+ Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png");
+ int rightId = operation.toSee(right);
+ int wrongId = operation.toSee(wrong);
+ System.out.println("该图是菜单:" + rightId);
+ System.out.println("该图是桌子:" + wrongId);
+ }
}