From d8d6a3197c98f2add50d446fdc929508fb4da5e1 Mon Sep 17 00:00:00 2001 From: lidapeng Date: Mon, 3 Feb 2020 18:05:30 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AB=E5=B0=BE=E5=8A=A0=E4=BA=86=E4=B8=80?= =?UTF-8?q?=E5=B1=82=E6=B1=A0=E5=8C=96=EF=BC=8C=E5=87=8F=E8=BD=BB=E7=9F=A9?= =?UTF-8?q?=E9=98=B5=E6=B1=82=E9=80=86=E7=9A=84=E5=8E=8B=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/MatrixTools/MatrixOperation.java | 46 +++++++++++++++ .../org/wlld/imageRecognition/Operation.java | 3 + .../wlld/imageRecognition/border/Border.java | 2 + src/test/java/org/wlld/HelloWorld.java | 58 ++++++++++++------- src/test/java/org/wlld/MatrixTest.java | 23 +++----- 5 files changed, 95 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index 3776366..c7e2b54 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -107,6 +107,52 @@ public class MatrixOperation { } } + public static Matrix getPoolVector(Matrix matrix) throws Exception { + if (matrix.getX() == 1 || matrix.getY() == 1) { + Matrix vector; + int nub; + boolean isRow = false; + if (matrix.getX() == 1) {//行向量 + isRow = true; + nub = matrix.getY() / 4; + vector = new Matrix(1, nub); + } else {//列向量 + nub = matrix.getX() / 4; + vector = new Matrix(nub, 1); + } + int k = 0; + for (int i = 0; i < nub * 4 - 3; i += 4) { + double max = 0; + if (isRow) { + max = matrix.getNumber(0, i); + max = getMax(max, matrix.getNumber(0, i + 1)); + max = getMax(max, matrix.getNumber(0, i + 2)); + max = getMax(max, matrix.getNumber(0, i + 3)); + vector.setNub(0, k, max); + } else { + max = matrix.getNumber(i, 0); + max = getMax(max, matrix.getNumber(i + 1, 0)); + max = getMax(max, matrix.getNumber(i + 2, 0)); + max = getMax(max, matrix.getNumber(i + 3, 0)); + vector.setNub(k, 0, max); + } + k++; + } + return vector; + } else { + throw new Exception("this matrix is not a vector"); + } + + } + + private static double getMax(double o1, double o2) { + if (o1 > o2) { + return o1; + } else { + return o2; + } + } + public static Matrix matrixToVector(Matrix matrix, boolean isRow) throws Exception {//将一个矩阵转成行向量 int x = matrix.getX(); int y = matrix.getY(); diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index 4159e95..7e460d7 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -203,7 +203,10 @@ public class Operation {//进行计算 Matrix yw = borderBody.getyW(); Matrix hw = borderBody.gethW(); Matrix ww = borderBody.getwW(); + //将矩阵化为横向量 matrix = MatrixOperation.matrixToVector(matrix, true); + //最后加一层池化 + matrix = MatrixOperation.getPoolVector(matrix); //将参数矩阵的末尾填1 matrix = MatrixOperation.push(matrix, 1, true); //锚点坐标及长宽预测值 diff --git a/src/main/java/org/wlld/imageRecognition/border/Border.java b/src/main/java/org/wlld/imageRecognition/border/Border.java index 341e615..fcb12ee 100644 --- a/src/main/java/org/wlld/imageRecognition/border/Border.java +++ b/src/main/java/org/wlld/imageRecognition/border/Border.java @@ -67,6 +67,8 @@ public class Border { double th = Math.log(ArithUtil.div(height, modelHeight)); //进行参数汇集 矩阵转化为行向量 matrix = MatrixOperation.matrixToVector(matrix, true); + //最后给一层池化层 + matrix = MatrixOperation.getPoolVector(matrix); //将参数矩阵的末尾填1 matrix = MatrixOperation.push(matrix, 1, true); if (matrixX == null) {//如果是第一次直接赋值 diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java index c1f344c..7c0ae07 100644 --- a/src/test/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -8,9 +8,11 @@ import org.wlld.imageRecognition.Operation; import org.wlld.imageRecognition.Picture; import org.wlld.imageRecognition.TempleConfig; import org.wlld.imageRecognition.border.Frame; +import org.wlld.imageRecognition.border.FrameBody; import org.wlld.nerveEntity.ModelParameter; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -28,34 +30,46 @@ public class HelloWorld { public static void test() throws Exception { Picture picture = new Picture(); TempleConfig templeConfig = new TempleConfig(); - ModelParameter modelParameter = JSONObject.parseObject(ModelData.DATA2, ModelParameter.class); + templeConfig.setHavePosition(true); + Frame frame = new Frame(); + frame.setWidth(3024); + frame.setHeight(4032); + frame.setLengthHeight(100); + frame.setLengthWidth(100); + templeConfig.setFrame(frame); + ModelParameter modelParameter = JSONObject.parseObject(ModelData.DATA, ModelParameter.class); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 3204, 4032, 1); templeConfig.insertModel(modelParameter); Operation operation = new Operation(templeConfig); -// for (int i = 1; i < 10; i++) {//faster rcnn神经网络学习 -// System.out.println("study==" + i); -// //读取本地URL地址图片,并转化成矩阵 -// Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); -// Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + i + ".png"); -// //将图像矩阵和标注加入进行学习,Accuracy_Pattern 模式 进行第二次学习 -// //第二次学习的时候,第三个参数必须是 true -// operation.learning(right, 1, true); -// operation.learning(wrong, 0, true); -// } - //templeConfig.clustering();//进行聚类 -// ModelParameter modelParameter1 = templeConfig.getModel(); -// String a = JSON.toJSONString(modelParameter1); -// System.out.println(a); - - //测试集图片,进行识别测试 + for (int i = 1; i < 300; i++) {//faster rcnn神经网络学习 + System.out.println("study==" + i); + //读取本地URL地址图片,并转化成矩阵 + Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); + Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + i + ".png"); + //将图像矩阵和标注加入进行学习,Accuracy_Pattern 模式 进行第二次学习 + //第二次学习的时候,第三个参数必须是 true + operation.learning(right, 1, true); + operation.learning(wrong, 0, true); + } + templeConfig.boxStudy();//边框回归 + templeConfig.clustering();//进行聚类 + ModelParameter modelParameter1 = templeConfig.getModel(); + String a = JSON.toJSONString(modelParameter1); + System.out.println(a); for (int j = 121; j < 140; j++) { Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); - Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png"); - int rightId = operation.toSee(right); - int wrongId = operation.toSee(wrong); - System.out.println("该图是菜单:" + rightId); - System.out.println("该图是桌子:" + wrongId); + Map> map = operation.lookWithPosition(right, j); + System.out.println("j===" + j); } + //测试集图片,进行识别测试 +// for (int j = 121; j < 140; j++) { +// Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); +// Matrix wrong = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/b/b" + j + ".png"); +// int rightId = operation.toSee(right); +// int wrongId = operation.toSee(wrong); +// System.out.println("该图是菜单:" + rightId); +// System.out.println("该图是桌子:" + wrongId); +// } } diff --git a/src/test/java/org/wlld/MatrixTest.java b/src/test/java/org/wlld/MatrixTest.java index 15b6913..fbb7cf0 100644 --- a/src/test/java/org/wlld/MatrixTest.java +++ b/src/test/java/org/wlld/MatrixTest.java @@ -4,6 +4,8 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; import org.wlld.imageRecognition.border.BorderBody; +import java.util.Random; + /** * @author lidapeng * @description @@ -11,24 +13,15 @@ import org.wlld.imageRecognition.border.BorderBody; */ public class MatrixTest { public static void main(String[] args) throws Exception { - double a = 3.33333; - double b = 3; - System.out.println(a / b); - //test3(); + test4(); } public static void test4() throws Exception { - BorderBody borderBody = new BorderBody(); - Matrix xw = borderBody.getxW(); - String a = "[1]#" + - "[3]#" + - "[5]#"; - xw = new Matrix(3, 1, a); - borderBody.setxW(xw); - Matrix xt = borderBody.getxW(); - xt = MatrixOperation.push(xt, 9, false); - Matrix xm = borderBody.getxW(); - System.out.println(xm.getString()); + Matrix matrix = new Matrix(1, 12); + String a = "[1,2,3,4,5,6,7,8,9,10,11,12]#"; + matrix.setAll(a); + Matrix matrix1 = MatrixOperation.getPoolVector(matrix); + System.out.println(matrix1.getString()); } public static void test3() throws Exception {