From 003804ec222edbc585614357150bf52cb0deca30 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Thu, 15 Oct 2020 17:04:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=89=B9=E5=BE=81=E9=80=89?= =?UTF-8?q?=E5=8C=BA=E7=AD=9B=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 2 +- .../org/wlld/MatrixTools/MatrixOperation.java | 19 +- src/main/java/org/wlld/config/Kernel.java | 4 + src/main/java/org/wlld/function/Tanh.java | 8 +- .../wlld/imageRecognition/Convolution.java | 70 ++--- .../wlld/imageRecognition/MeanClustering.java | 95 ++++-- .../org/wlld/imageRecognition/Operation.java | 10 +- .../org/wlld/imageRecognition/RGBNorm.java | 27 +- .../imageRecognition/modelEntity/RgbBack.java | 37 +++ .../segmentation/FindMaxSimilar.java | 138 --------- .../segmentation/KNerveManger.java | 117 ++++++++ .../segmentation/RgbRegression.java | 99 +++++- .../segmentation/WFilter.java | 55 ---- .../segmentation/Watershed.java | 20 +- src/main/java/org/wlld/nerveEntity/Nerve.java | 32 +- .../java/org/wlld/nerveEntity/SoftMax.java | 8 +- src/main/java/org/wlld/param/Food.java | 19 ++ src/test/java/coverTest/DataObservation.java | 101 +++++++ src/test/java/coverTest/FoodTest.java | 107 ++++--- src/test/java/coverTest/ForestTest.java | 5 +- src/test/java/coverTest/PicTest.java | 2 +- src/test/java/coverTest/RGBBody.java | 46 +++ .../java/coverTest/regionCut/RegionCut.java | 284 ++++++++++++++++++ .../coverTest/regionCut/RegionCutBody.java | 15 + .../coverTest/regionCut/RegionFeature.java | 84 ++++++ src/test/java/org/wlld/NerveDemo1.java | 4 +- 26 files changed, 1053 insertions(+), 355 deletions(-) create mode 100644 src/main/java/org/wlld/imageRecognition/modelEntity/RgbBack.java delete mode 100644 src/main/java/org/wlld/imageRecognition/segmentation/FindMaxSimilar.java create mode 100644 src/main/java/org/wlld/imageRecognition/segmentation/KNerveManger.java delete mode 100644 src/main/java/org/wlld/imageRecognition/segmentation/WFilter.java create mode 100644 src/test/java/coverTest/DataObservation.java create mode 100644 src/test/java/coverTest/RGBBody.java create mode 100644 src/test/java/coverTest/regionCut/RegionCut.java create mode 100644 src/test/java/coverTest/regionCut/RegionCutBody.java create mode 100644 src/test/java/coverTest/regionCut/RegionFeature.java diff --git a/pom.xml b/pom.xml index 4fc633a..34cef63 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.wlld easyAi - 1.0.9 + 1.0.0 easyAi diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index 641b31b..0bc6765 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -51,12 +51,16 @@ public class MatrixOperation { Matrix matrix1 = transPosition(parameter); //转置的参数矩阵乘以参数矩阵 Matrix matrix2 = mulMatrix(matrix1, parameter); - //求上一步的逆矩阵 + //求上一步的逆矩阵 这一步需要矩阵非奇异,若出现奇异矩阵,则返回0矩阵,意味失败 Matrix matrix3 = getInverseMatrixs(matrix2); - //逆矩阵乘以转置矩阵 - Matrix matrix4 = mulMatrix(matrix3, matrix1); - //最后乘以输出矩阵,生成权重矩阵并返回 - return mulMatrix(matrix4, out); + if (matrix3.getX() == 1 && matrix3.getY() == 1) { + return matrix3; + } else { + //逆矩阵乘以转置矩阵 + Matrix matrix4 = mulMatrix(matrix3, matrix1); + //最后乘以输出矩阵,生成权重矩阵并返回 + return mulMatrix(matrix4, out); + } } else { throw new Exception("invalid regression matrix"); } @@ -347,8 +351,9 @@ public class MatrixOperation { mathMul(myMatrix, def); return myMatrix; } else { - System.out.println(matrix.getString()); - throw new Exception("this matrix do not have InverseMatrixs"); + //System.out.println("matrix def is zero error:"); + //System.out.println(matrix.getString()); + return new Matrix(1, 1); } } diff --git a/src/main/java/org/wlld/config/Kernel.java b/src/main/java/org/wlld/config/Kernel.java index 465295b..c0cf7b5 100644 --- a/src/main/java/org/wlld/config/Kernel.java +++ b/src/main/java/org/wlld/config/Kernel.java @@ -8,6 +8,8 @@ public class Kernel { private static final String Horizontal_Number = "[-1,-2,-1]#[0,0,0]#[1,2,1]#";//横卷积核 private static final String All_Number = "[1,-2,1]#[-2,4,-2]#[1,-2,1]#";//角卷积 private static final String All_Number2 = "[-1,0,-1]#[0,4,0]#[-1,0,-1]#"; + private static final String All_Big = "[-1,0,0,0,-1]#[0,-1,0,-1,0]#[0,0,8,0,0]#" + + "[0,-1,0,-1,0]#[-1,0,0,0,-1]"; public static final int Region_Nub = 60;//一张图有多少份 public static final double th = 0.88;//分水岭灰度阈值 public static final double rgbN = 442.0;//442.0;//RGB范数归一化最大值 @@ -15,6 +17,7 @@ public class Kernel { public static Matrix Horizontal; public static Matrix All; public static Matrix ALL_Two; + public static Matrix Big; public static final int Unit = 100; public static final double Pi = ArithUtil.div(ArithUtil.div(Math.PI, 2), Unit); @@ -23,6 +26,7 @@ public class Kernel { static { try { + Big = new Matrix(5, 5, All_Big); ALL_Two = new Matrix(3, 3, All_Number2); All = new Matrix(3, 3, All_Number); Vertical = new Matrix(3, 3, Vertical_Number); diff --git a/src/main/java/org/wlld/function/Tanh.java b/src/main/java/org/wlld/function/Tanh.java index 13b3cc1..776c3e7 100644 --- a/src/main/java/org/wlld/function/Tanh.java +++ b/src/main/java/org/wlld/function/Tanh.java @@ -8,13 +8,13 @@ public class Tanh implements ActiveFunction { public double function(double x) { double x1 = Math.exp(x); double x2 = Math.exp(-x); - double son = ArithUtil.sub(x1, x2); - double mother = ArithUtil.add(x1, x2); - return ArithUtil.div(son, mother); + double son = x1 - x2;// ArithUtil.sub(x1, x2); + double mother = x1 + x2;// ArithUtil.add(x1, x2); + return son / mother;//ArithUtil.div(son, mother); } @Override public double functionG(double out) { - return ArithUtil.sub(1, Math.pow(function(out), 2)); + return 1 - Math.pow(function(out), 2);//ArithUtil.sub(1, Math.pow(function(out), 2)); } } diff --git a/src/main/java/org/wlld/imageRecognition/Convolution.java b/src/main/java/org/wlld/imageRecognition/Convolution.java index 2e1c468..c7b14be 100644 --- a/src/main/java/org/wlld/imageRecognition/Convolution.java +++ b/src/main/java/org/wlld/imageRecognition/Convolution.java @@ -6,8 +6,6 @@ import org.wlld.config.Kernel; import org.wlld.imageRecognition.border.Border; import org.wlld.imageRecognition.border.Frame; import org.wlld.imageRecognition.border.FrameBody; -import org.wlld.imageRecognition.modelEntity.RegressionBody; -import org.wlld.imageRecognition.segmentation.RgbRegression; import org.wlld.tools.ArithUtil; import org.wlld.tools.Frequency; @@ -70,7 +68,7 @@ public class Convolution extends Frequency { List threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize); for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) { List feature = new ArrayList<>(); - MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig); + MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig, true); Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixB = threeChannelMatrix.getMatrixB(); @@ -175,7 +173,7 @@ public class Convolution extends Frequency { RGBSort rgbSort = new RGBSort(); int x = matrixR.getX(); int y = matrixR.getY(); - MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig); + MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig, true); for (int i = 0; i < x; i++) { for (int j = 0; j < y; j++) { double[] color = new double[]{matrixR.getNumber(i, j), matrixG.getNumber(i, j), matrixB.getNumber(i, j)}; @@ -201,15 +199,13 @@ public class Convolution extends Frequency { } public List getCenterTexture(ThreeChannelMatrix threeChannelMatrix, int size, int poolSize, TempleConfig templeConfig - , int sqNub) throws Exception { + , int sqNub, int tag) throws Exception { RGBSort rgbSort = new RGBSort(); - double dispersedThNub = templeConfig.getFood().getDispersedTh(); - int step = templeConfig.getFood().getStep(); - MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig); + MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig, true); Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixB = threeChannelMatrix.getMatrixB(); - Matrix matrixH = threeChannelMatrix.getH(); + Matrix matrixRGB = threeChannelMatrix.getMatrixRGB(); int xn = matrixR.getX(); int yn = matrixR.getY(); // for (int i = 0; i < xn; i++) { @@ -220,47 +216,41 @@ public class Convolution extends Frequency { // } // } //局部特征选区筛选 - double sigma = 0; - int nub = 0; - for (int i = 0; i <= xn - size; i += step) { - for (int j = 0; j <= yn - size; j += step) { - Matrix sonH = matrixH.getSonOfMatrix(i, j, size, size); - double[] h = new double[size * size]; - nub++; - for (int t = 0; t < size; t++) { - for (int k = 0; k < size; k++) { - int index = t * size + k; - h[index] = sonH.getNumber(t, k); - } - } - sigma = dc(h) + sigma; - } - } - double dispersedTh = (sigma / nub) * dispersedThNub;//离散阈值 - for (int i = 0; i <= xn - size; i += step) { - for (int j = 0; j <= yn - size; j += step) { + int nub = size * size; + int twoNub = nub * 2; + for (int i = 0; i <= xn - size; i++) { + for (int j = 0; j <= yn - size; j++) { Matrix sonR = matrixR.getSonOfMatrix(i, j, size, size); Matrix sonG = matrixG.getSonOfMatrix(i, j, size, size); Matrix sonB = matrixB.getSonOfMatrix(i, j, size, size); - Matrix sonH = matrixH.getSonOfMatrix(i, j, size, size); - double[] h = new double[size * size]; - double[] rgb = new double[size * size * 3]; + Matrix sonRGB = matrixRGB.getSonOfMatrix(i, j, size, size); + double[] h = new double[nub]; + double[] rgb = new double[nub * 3]; for (int t = 0; t < size; t++) { for (int k = 0; k < size; k++) { int index = t * size + k; - h[index] = sonH.getNumber(t, k); - rgb[index] = sonR.getNumber(t, k); - rgb[size * size + index] = sonG.getNumber(t, k); - rgb[size * size * 2 + index] = sonB.getNumber(t, k); + h[index] = sonRGB.getNumber(t, k); + rgb[index] = sonR.getNumber(t, k) / 255; + rgb[nub + index] = sonG.getNumber(t, k) / 255; + rgb[twoNub + index] = sonB.getNumber(t, k) / 255; } } - double dispersed = dc(h); - if (dispersed < dispersedTh) { - meanClustering.setColor(rgb); + double dispersed = variance(h); + if (dispersed < 900 && dispersed > 200) { + for (int m = 0; m < nub; m++) { + double[] color = new double[]{rgb[m], rgb[m + nub], rgb[m + twoNub]}; + meanClustering.setColor(color); + } + // meanClustering.setColor(rgb); } } } - meanClustering.start(false);//开始聚类 + List list = meanClustering.start(true);//开始聚类 + if (tag == 0) {//识别 + templeConfig.getFood().getkNerveManger().look(list); + } else {//训练 + templeConfig.getFood().getkNerveManger().setFeature(tag, list); + } List rgbNorms = meanClustering.getMatrices(); Collections.sort(rgbNorms, rgbSort); List features = new ArrayList<>(); @@ -270,7 +260,7 @@ public class Convolution extends Frequency { features.add(rgb[j]); } } - //System.out.println(features); + // System.out.println(features); return features; } diff --git a/src/main/java/org/wlld/imageRecognition/MeanClustering.java b/src/main/java/org/wlld/imageRecognition/MeanClustering.java index 8e51792..4b42e41 100644 --- a/src/main/java/org/wlld/imageRecognition/MeanClustering.java +++ b/src/main/java/org/wlld/imageRecognition/MeanClustering.java @@ -1,6 +1,6 @@ package org.wlld.imageRecognition; -import org.wlld.imageRecognition.segmentation.RgbRegression; +import org.wlld.param.Food; import java.util.*; @@ -11,20 +11,31 @@ public class MeanClustering { private int speciesQuantity;//种类数量(模型需要返回) private List matrices = new ArrayList<>();//均值K模型(模型需要返回) private int size = 10000; + private TempleConfig templeConfig; + private int sensoryNerveNub;//神经元个数 + private List kList = new ArrayList<>(); public List getMatrices() { return matrices; } - public MeanClustering(int speciesQuantity, TempleConfig templeConfig) { + public MeanClustering(int speciesQuantity, TempleConfig templeConfig, boolean isFirst) throws Exception { this.speciesQuantity = speciesQuantity;//聚类的数量 - size = templeConfig.getFood().getRegressionNub(); + Food food = templeConfig.getFood(); + size = food.getRegressionNub(); + this.templeConfig = templeConfig; +// if (isFirst) { +// for (int i = 0; i < speciesQuantity; i++) { +// kList.add(new MeanClustering(10, templeConfig, false)); +// } +// } } public void setColor(double[] color) throws Exception { if (matrixList.size() == 0) { matrixList.add(color); length = color.length; + sensoryNerveNub = templeConfig.getFeatureNub() * length; } else { if (length == color.length) { matrixList.add(color); @@ -73,24 +84,60 @@ public class MeanClustering { } } - private void startRegression() throws Exception {//开始聚类回归 + private List startBp() { + int times = 2000 + 1; + List features = new ArrayList<>(); + List> lists = new ArrayList<>(); + for (int j = 0; j < matrices.size(); j++) { + List list = matrices.get(j).getRgbs().subList(0, times); + lists.add(list); + } + for (int j = 0; j < times; j++) { + double[] feature = new double[sensoryNerveNub]; + for (int i = 0; i < lists.size(); i++) { + double[] data = lists.get(i).get(j); + int len = data.length; + for (int k = 0; k < len; k++) { + feature[i * len + k] = data[k]; + } + } + features.add(feature); + } + return features; + } + + private List startRegression() throws Exception {//开始聚类回归 + for (int i = 0; i < matrices.size(); i++) { + List list = matrices.get(i).getRgbs(); + MeanClustering k = kList.get(i); + for (double[] rgb : list) { + k.setColor(rgb); + } + k.start(false); + } + //遍历子聚类 + int times = 2000; Random random = new Random(); - for (RGBNorm rgbNorm : matrices) { - RgbRegression rgbRegression = new RgbRegression(size); - List list = rgbNorm.getRgbs(); - for (int i = 0; i < size; i++) { - double[] rgb = list.get(random.nextInt(list.size())); - rgb[0] = rgb[0] / 255; - rgb[1] = rgb[1] / 255; - rgb[2] = rgb[2] / 255; - rgbRegression.insertRGB(rgb); + List features = new ArrayList<>(); + for (int i = 0; i < times; i++) { + double[] feature = new double[sensoryNerveNub]; + for (int k = 0; k < kList.size(); k++) { + MeanClustering mean = kList.get(k); + List rgbNorms = mean.getMatrices(); + double[] rgb = rgbNorms.get(random.nextInt(rgbNorms.size())).getRgb(); + int rgbLen = rgb.length; + for (int t = 0; t < rgbLen; t++) { + int index = k * rgbLen + t; + feature[index] = rgb[t]; + } } - rgbRegression.regression(); - rgbNorm.setRgbRegression(rgbRegression); + //System.out.println(Arrays.toString(feature)); + features.add(feature); } + return features; } - public void start(boolean isRegression) throws Exception {//开始聚类 + public List start(boolean isRegression) throws Exception {//开始聚类 if (matrixList.size() > 1) { Random random = new Random(); for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量 @@ -102,18 +149,26 @@ public class MeanClustering { } //进行两者的比较 boolean isNext; - for (int i = 0; i < 30; i++) { + for (int i = 0; i < 40; i++) { averageMatrix(); isNext = isNext(); - if (isNext && i < 29) { + if (isNext && i < 39) { clear(); } else { break; } } - if (isRegression) { - startRegression();//开始进行回归 + RGBSort rgbSort = new RGBSort(); + Collections.sort(matrices, rgbSort); + for (RGBNorm rgbNorm : matrices) { + rgbNorm.finish(); } +// if (isRegression) { +// return startRegression(); +// } else { +// return null; +// } + return startBp(); } else { throw new Exception("matrixList number less than 2"); } diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index 1918278..aa9d8ce 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -92,7 +92,8 @@ public class Operation {//进行计算 threeChannelMatrix.setMatrixB(matrixB.getSonOfMatrix(x, y, xSize, ySize)); } - public RegionBody colorStudy(ThreeChannelMatrix threeChannelMatrix, int tag, List specificationsList) throws Exception { + public RegionBody colorStudy(ThreeChannelMatrix threeChannelMatrix, int tag, List specificationsList + , String url) throws Exception { Watershed watershed = new Watershed(threeChannelMatrix, specificationsList, templeConfig); List regionBodies = watershed.rainfall(); if (regionBodies.size() == 1) { @@ -108,7 +109,7 @@ public class Operation {//进行计算 // List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), // templeConfig.getFeatureNub(), templeConfig); List feature = convolution.getCenterTexture(threeChannelMatrix1, templeConfig.getFood().getRegionSize(), - templeConfig.getPoolSize(), templeConfig, templeConfig.getFeatureNub()); + templeConfig.getPoolSize(), templeConfig, templeConfig.getFeatureNub(), tag); if (templeConfig.isShowLog()) { System.out.println(tag + ":" + feature); } @@ -144,7 +145,8 @@ public class Operation {//进行计算 int minY = regionBody.getMinY(); int maxX = regionBody.getMaxX(); int maxY = regionBody.getMaxY(); - System.out.println("异常:minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); + System.out.println("异常:minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY + ",tag==" + tag + + "url==" + url); } throw new Exception("Parameter exception region size==" + regionBodies.size()); } @@ -183,7 +185,7 @@ public class Operation {//进行计算 // List feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), // templeConfig.getFeatureNub(), templeConfig); List feature = convolution.getCenterTexture(threeChannelMatrix1, templeConfig.getFood().getRegionSize(), - templeConfig.getPoolSize(), templeConfig, templeConfig.getFeatureNub()); + templeConfig.getPoolSize(), templeConfig, templeConfig.getFeatureNub(), 0); if (templeConfig.isShowLog()) { System.out.println(feature); } diff --git a/src/main/java/org/wlld/imageRecognition/RGBNorm.java b/src/main/java/org/wlld/imageRecognition/RGBNorm.java index fef6cd0..0f8d88f 100644 --- a/src/main/java/org/wlld/imageRecognition/RGBNorm.java +++ b/src/main/java/org/wlld/imageRecognition/RGBNorm.java @@ -4,6 +4,8 @@ import org.wlld.imageRecognition.segmentation.RgbRegression; import org.wlld.tools.ArithUtil; import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.List; public class RGBNorm { @@ -12,7 +14,7 @@ public class RGBNorm { private int nub; private double[] rgb; private double[] rgbUp; - private List rgbs = new ArrayList<>(); + private List rgbs = new ArrayList<>();//需要对它进行排序 private RgbRegression rgbRegression; private int len; @@ -96,6 +98,11 @@ public class RGBNorm { } } + public void finish() {//进行排序 + RGBListSort rgbListSort = new RGBListSort(); + Collections.sort(rgbs, rgbListSort); + } + public double getNorm() { return norm; } @@ -103,4 +110,22 @@ public class RGBNorm { public double[] getRgb() { return rgb; } + + class RGBListSort implements Comparator { + @Override + public int compare(double[] o1, double[] o2) { + double o1Norm = 0; + double o2Norm = 0; + for (int i = 0; i < o1.length; i++) { + o1Norm = o1Norm + Math.pow(o1[i], 2); + o2Norm = o2Norm + Math.pow(o2[i], 2); + } + if (o1Norm > o2Norm) { + return -1; + } else if (o1Norm < o2Norm) { + return 1; + } + return 0; + } + } } diff --git a/src/main/java/org/wlld/imageRecognition/modelEntity/RgbBack.java b/src/main/java/org/wlld/imageRecognition/modelEntity/RgbBack.java new file mode 100644 index 0000000..7ef5b5f --- /dev/null +++ b/src/main/java/org/wlld/imageRecognition/modelEntity/RgbBack.java @@ -0,0 +1,37 @@ +package org.wlld.imageRecognition.modelEntity; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.i.OutBack; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class RgbBack implements OutBack { + private int id = 0; + private double out = 0; + + public void clear() { + out = 0; + id = 0; + } + + @Override + public void getBack(double out, int id, long eventId) { + if (out > this.out) { + this.out = out; + this.id = id; + } + } + + @Override + public void getBackMatrix(Matrix matrix, long eventId) { + + } + + public int getId() { + return id; + } +} diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/FindMaxSimilar.java b/src/main/java/org/wlld/imageRecognition/segmentation/FindMaxSimilar.java deleted file mode 100644 index 7c38640..0000000 --- a/src/main/java/org/wlld/imageRecognition/segmentation/FindMaxSimilar.java +++ /dev/null @@ -1,138 +0,0 @@ -package org.wlld.imageRecognition.segmentation; - -import org.wlld.MatrixTools.Matrix; -import org.wlld.MatrixTools.MatrixOperation; -import org.wlld.imageRecognition.ThreeChannelMatrix; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * @param - * @DATA - * @Author LiDaPeng - * @Description 寻找相似度最大的选区 - */ -public class FindMaxSimilar { - - public void findMaxSimilar(ThreeChannelMatrix threeChannelMatrix, int size) throws Exception { - Map threeChannelMatrices = new HashMap<>(); - Matrix matrixR = threeChannelMatrix.getMatrixR(); - Matrix matrixG = threeChannelMatrix.getMatrixG(); - Matrix matrixB = threeChannelMatrix.getMatrixB(); - Matrix matrixRGB = threeChannelMatrix.getMatrixRGB(); - int x = matrixR.getX(); - int y = matrixR.getY(); - System.out.println("初始区域数量:" + (x * y)); - int index = 0; - for (int i = 0; i <= x - size; i += size) { - for (int j = 0; j <= y - size; j += size) { - ThreeChannelMatrix threeChannelMatrix1 = new ThreeChannelMatrix(); - Matrix bodyR = matrixR.getSonOfMatrix(i, j, size, size); - Matrix bodyG = matrixG.getSonOfMatrix(i, j, size, size); - Matrix bodyB = matrixB.getSonOfMatrix(i, j, size, size); - Matrix bodyRGB = matrixRGB.getSonOfMatrix(i, j, size, size); - threeChannelMatrix1.setMatrixR(bodyR); - threeChannelMatrix1.setMatrixG(bodyG); - threeChannelMatrix1.setMatrixB(bodyB); - threeChannelMatrix1.setMatrixRGB(bodyRGB); - threeChannelMatrices.put(index, threeChannelMatrix1); - index++; - } - } - //切割完毕 ,开始寻找最大相似性 - for (Map.Entry entry : threeChannelMatrices.entrySet()) { - int key = entry.getKey(); - ThreeChannelMatrix threeChannelMatrix1 = entry.getValue(); - Matrix matrix1 = threeChannelMatrix1.getMatrixRGB(); - double minDist = -1; - int similarId = 0; - for (Map.Entry entrySon : threeChannelMatrices.entrySet()) { - int sonKey = entrySon.getKey(); - if (key != sonKey) { - Matrix matrix2 = entrySon.getValue().getMatrixRGB(); - double dist = MatrixOperation.getEDistByMatrix(matrix1, matrix2); - if (minDist == -1 || dist < minDist) { - minDist = dist; - similarId = sonKey; - } - } - } - threeChannelMatrix1.setSimilarId(similarId); - } - //最大相似性区域已经查找完毕,开始进行连线 - Map> lineMap = line(threeChannelMatrices); - //System.out.println("size:" + lineMap.size()); - int max = 0; - int maxK = 0; - for (Map.Entry> entry : lineMap.entrySet()) { - int key = entry.getKey(); - int nub = entry.getValue().size(); - if (nub > max) { - max = nub; - maxK = key; - } - } - System.out.println("max:" + max); - int key = lineMap.get(maxK).get(0); - Matrix matrix = threeChannelMatrices.get(key).getMatrixR(); - System.out.println(matrix.getString()); - - - } - - private void merge(Map> lineMap, Map map, - int x, int y) throws Exception { - Map map2 = new HashMap<>(); - for (Map.Entry> entry : lineMap.entrySet()) { - int key = entry.getKey(); - Matrix matrixAll = new Matrix(x, y); - List list = entry.getValue(); - int len = list.size(); - for (int i = 0; i < len; i++) { - ThreeChannelMatrix threeChannelMatrix = map.get(list.get(i)); - Matrix matrixRGB = threeChannelMatrix.getMatrixRGB(); - matrixAll = MatrixOperation.add(matrixAll, matrixRGB); - } - MatrixOperation.mathDiv(matrixAll, len);//矩阵数除 - map2.put(key, matrixAll); - } - - } - - private Map> line(Map threeChannelMatrices) {//开始进行连线 - Map> lineMap = new HashMap<>(); - for (Map.Entry entry : threeChannelMatrices.entrySet()) { - int key = entry.getKey();//当前进行连线的id - ThreeChannelMatrix myThreeChannelMatrix = entry.getValue(); - boolean isLine = myThreeChannelMatrix.isLine();//是否被连线 - int upIndex = key;//上一个进行连线的id - if (!isLine) {//可以进行连线 - List list = new ArrayList<>(); - lineMap.put(key, list); - list.add(key); - myThreeChannelMatrix.setLine(true); - boolean line; - do { - int similarId = myThreeChannelMatrix.getSimilarId();//距离它最近的id - ThreeChannelMatrix threeChannelMatrix = threeChannelMatrices.get(similarId); - line = threeChannelMatrix.isLine(); - if (!line) {//进行连线 - list.add(similarId); - threeChannelMatrix.setLine(true); - //如果当前被连线的矩阵的最近id为连线者本身,则连线后停止遍历 - if (upIndex == threeChannelMatrix.getSimilarId()) {//停止连线 - line = true; - } else {//继续连线z - upIndex = similarId; - myThreeChannelMatrix = threeChannelMatrix; - } - } - } while (!line); - } - } - return lineMap; - } -} diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/KNerveManger.java b/src/main/java/org/wlld/imageRecognition/segmentation/KNerveManger.java new file mode 100644 index 0000000..be1d0f2 --- /dev/null +++ b/src/main/java/org/wlld/imageRecognition/segmentation/KNerveManger.java @@ -0,0 +1,117 @@ +package org.wlld.imageRecognition.segmentation; + +import org.wlld.config.RZ; +import org.wlld.function.Sigmod; +import org.wlld.function.Tanh; +import org.wlld.imageRecognition.modelEntity.RgbBack; +import org.wlld.nerveCenter.NerveManager; +import org.wlld.nerveEntity.SensoryNerve; + +import java.awt.image.Kernel; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class KNerveManger { + private Map> featureMap = new HashMap<>(); + private int sensoryNerveNub;//输出神经元个数 + private int speciesNub;//种类数 + private NerveManager nerveManager; + private int times; + private RgbBack rgbBack = new RgbBack(); + + public KNerveManger(int sensoryNerveNub, int speciesNub, int times) throws Exception { + this.sensoryNerveNub = sensoryNerveNub; + this.speciesNub = speciesNub; + this.times = times; + nerveManager = new NerveManager(sensoryNerveNub, 24, speciesNub, + 1, new Tanh(),//0.008 l1 0.02 + false, false, 0.008, RZ.L1, 0.01); + nerveManager.init(true, false, true, true); + } + + private Map createTag(int tag) {//创建一个标注 + Map tagging = new HashMap<>(); + Set set = featureMap.keySet(); + for (int key : set) { + double value = 0.0; + if (key == tag) { + value = 1.0; + } + tagging.put(key, value); + } + return tagging; + } + + public void look(List data) throws Exception { + int size = data.size(); + Map map = new HashMap<>(); + for (int i = 0; i < size; i++) { + rgbBack.clear(); + post(data.get(i), null, false); + int type = rgbBack.getId(); + if (map.containsKey(type)) { + map.put(type, map.get(type) + 1); + } else { + map.put(type, 1); + } + } + double max = 0; + int type = 0; + for (Map.Entry entry : map.entrySet()) { + int nub = entry.getValue(); + if (nub > max) { + max = nub; + type = entry.getKey(); + } + } + double point = max / size; + System.out.println("类型是:" + type + ",总票数:" + size + ",得票率:" + point); + System.out.println("=================================完成"); + } + + public void startStudy() throws Exception { + for (int i = 0; i < times; i++) { + for (Map.Entry> entry : featureMap.entrySet()) { + int type = entry.getKey(); + System.out.println("=============================" + type); + Map tag = createTag(type);//标注 + double[] feature = entry.getValue().get(i);//数据 + post(feature, tag, true); + } + } + +// for (Map.Entry> entry : featureMap.entrySet()) { +// int type = entry.getKey(); +// System.out.println("=============================" + type); +// List list = entry.getValue(); +// look(list); +// } + } + + private void post(double[] data, Map tagging, boolean isStudy) throws Exception { + List sensoryNerveList = nerveManager.getSensoryNerves(); + int size = sensoryNerveList.size(); + for (int i = 0; i < size; i++) { + sensoryNerveList.get(i).postMessage(1, data[i], isStudy, tagging, rgbBack); + } + } + + public void setFeature(int type, List feature) { + if (type > 0) { + if (featureMap.containsKey(type)) { + featureMap.get(type).addAll(feature); + } else { + featureMap.put(type, feature); + } + } + } + +} diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/RgbRegression.java b/src/main/java/org/wlld/imageRecognition/segmentation/RgbRegression.java index 972e9ee..ac47793 100644 --- a/src/main/java/org/wlld/imageRecognition/segmentation/RgbRegression.java +++ b/src/main/java/org/wlld/imageRecognition/segmentation/RgbRegression.java @@ -15,8 +15,56 @@ public class RgbRegression { private double b; private Matrix RG;//rg矩阵 private Matrix B;//b矩阵 + private Matrix RGB;//rgb矩阵 private int xIndex = 0;//记录插入数量 private boolean isRegression = false;//是否进行了回归 + private int regionNub; + private int x; + private int y; + + public Matrix getRGB() { + return RGB; + } + + public int getX() { + return x; + } + + public void setX(int x) { + this.x = x; + } + + public int getY() { + return y; + } + + public void setY(int y) { + this.y = y; + } + + public Matrix getRGMatrix() { + return RG; + } + + public Matrix getBMatrix() { + return B; + } + + public void clear(int size) { + RG = new Matrix(size, 3); + RGB = new Matrix(size, 3); + B = new Matrix(size, 1); + xIndex = 0; + regionNub = size; + } + + public int getRegionNub() { + return regionNub; + } + + public void setRegionNub(int regionNub) { + this.regionNub = regionNub; + } public double getWr() { return wr; @@ -32,11 +80,47 @@ public class RgbRegression { public RgbRegression(int size) {//初始化rgb矩阵 RG = new Matrix(size, 3); + RGB = new Matrix(size, 3); B = new Matrix(size, 1); + regionNub = size; + xIndex = 0; + } + + public void mergeRegion(RgbRegression rgbRegression) throws Exception { + Matrix myRG = rgbRegression.getRGMatrix(); + Matrix myB = rgbRegression.getBMatrix(); + int nub = rgbRegression.getRegionNub();//合并过来的数据量 + int size = nub + regionNub;//扩容后新的数据量 +// Matrix NRG = new Matrix(size, 3); +// Matrix NB = new Matrix(size, 1); +// for (int i = 0; i < size; i++) { +// Matrix RGT; +// Matrix BT; +// int t = i; +// if (i < regionNub) {//加载本来的数据 +// RGT = RG; +// BT = B; +// } else { +// RGT = myRG; +// BT = myB; +// t = t - regionNub; +// } +// for (int j = 0; j < 3; j++) { +// NRG.setNub(i, j, RGT.getNumber(t, j)); +// } +// NB.setNub(i, 0, BT.getNumber(t, 0)); +// } + regionNub = size; +// RG = NRG; +// B = NB; +// regression();//最新数据进行回归 } public void insertRGB(double[] rgb) throws Exception {//rgb插入矩阵 if (rgb.length == 3) { + RGB.setNub(xIndex, 0, rgb[0]); + RGB.setNub(xIndex, 1, rgb[1]); + RGB.setNub(xIndex, 2, rgb[2]); RG.setNub(xIndex, 0, rgb[0]); RG.setNub(xIndex, 1, rgb[1]); RG.setNub(xIndex, 2, 1.0); @@ -47,13 +131,18 @@ public class RgbRegression { } } - public void regression() throws Exception {//开始进行回归 + public boolean regression() throws Exception {//开始进行回归 if (xIndex > 0) { Matrix ws = MatrixOperation.getLinearRegression(RG, B); - wr = ws.getNumber(0, 0); - wg = ws.getNumber(1, 0); - b = ws.getNumber(2, 0); - isRegression = true; + if (ws.getX() == 1 && ws.getY() == 1) {//矩阵奇异 + isRegression = false; + } else { + wr = ws.getNumber(0, 0); + wg = ws.getNumber(1, 0); + b = ws.getNumber(2, 0); + isRegression = true; + } + return isRegression; // System.out.println("wr==" + wr + ",wg==" + wg + ",b==" + b); } else { throw new Exception("regression matrix size is zero"); diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/WFilter.java b/src/main/java/org/wlld/imageRecognition/segmentation/WFilter.java deleted file mode 100644 index 63c7e75..0000000 --- a/src/main/java/org/wlld/imageRecognition/segmentation/WFilter.java +++ /dev/null @@ -1,55 +0,0 @@ -package org.wlld.imageRecognition.segmentation; - -import org.wlld.MatrixTools.Matrix; -import org.wlld.imageRecognition.MeanClustering; -import org.wlld.imageRecognition.RGBNorm; -import org.wlld.imageRecognition.TempleConfig; -import org.wlld.imageRecognition.ThreeChannelMatrix; - -import java.util.List; - -/** - * @param - * @DATA - * @Author LiDaPeng - * @Description - */ -public class WFilter { - private List rgbNorms; - - public void filter(ThreeChannelMatrix threeChannelMatrix, TempleConfig templeConfig, - int speciesQuantity) throws Exception { - MeanClustering meanClustering = new MeanClustering(speciesQuantity, templeConfig); - Matrix matrixR = threeChannelMatrix.getMatrixR(); - Matrix matrixG = threeChannelMatrix.getMatrixG(); - Matrix matrixB = threeChannelMatrix.getMatrixB(); - int x = matrixR.getX(); - int y = matrixR.getY(); - for (int i = 0; i < x; i++) { - for (int j = 0; j < y; j++) { - double[] color = new double[]{matrixR.getNumber(i, j), matrixG.getNumber(i, j), matrixB.getNumber(i, j)}; - meanClustering.setColor(color); - } - } - meanClustering.start(true); - rgbNorms = meanClustering.getMatrices(); - - } - - private void getDist() { - double min = -1; - for (RGBNorm rgbNorm : rgbNorms) { - double[] rgb = rgbNorm.getRgb(); - - } - } - - private double dist(double[] a, double[] b) { - double sigma = 0; - for (int i = 0; i < a.length; i++) { - double sub = Math.pow(a[i] - b[i], 2); - sigma = sigma + sub; - } - return sigma / a.length; - } -} diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java index f968dc4..4885282 100644 --- a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java +++ b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java @@ -75,15 +75,17 @@ public class Watershed { private boolean isTray(int x, int y) throws Exception { boolean isTray = false; -// double[] rgb = new double[]{matrixR.getNumber(x, y) / 255, matrixG.getNumber(x, y) / 255, -// matrixB.getNumber(x, y) / 255}; -// for (RgbRegression rgbRegression : trayBody) { -// double dist = rgbRegression.getDisError(rgb); -// if (dist < trayTh) { -// isTray = true; -// break; -// } -// } + if (trayBody != null && trayBody.size() > 0) { + double[] rgb = new double[]{matrixR.getNumber(x, y) / 255, matrixG.getNumber(x, y) / 255, + matrixB.getNumber(x, y) / 255}; + for (RgbRegression rgbRegression : trayBody) { + double dist = rgbRegression.getDisError(rgb); + if (dist < trayTh) { + isTray = true; + break; + } + } + } return isTray; } diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index 876d515..c11f463 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -186,10 +186,12 @@ public abstract class Nerve { private void backGetMessage(double parameter, long eventId) throws Exception {//反向传播 backNub++; - sigmaW = ArithUtil.add(sigmaW, parameter); + //sigmaW = ArithUtil.add(sigmaW, parameter); + sigmaW = sigmaW + parameter; if (backNub == downNub) {//进行新的梯度计算 backNub = 0; - gradient = ArithUtil.mul(activeFunction.functionG(outNub), sigmaW); + //gradient = ArithUtil.mul(activeFunction.functionG(outNub), sigmaW); + gradient = activeFunction.functionG(outNub) * sigmaW; updatePower(eventId);//修改阈值 } } @@ -218,8 +220,10 @@ public abstract class Nerve { } protected void updatePower(long eventId) throws Exception {//修改阈值 - double h = ArithUtil.mul(gradient, studyPoint);//梯度下降 - threshold = ArithUtil.add(threshold, -h);//更新阈值 + //double h = ArithUtil.mul(gradient, studyPoint);//梯度下降 + double h = gradient * studyPoint; + //threshold = ArithUtil.add(threshold, -h);//更新阈值 + threshold = threshold - h; updateW(h, eventId); sigmaW = 0;//求和结果归零 backSendMessage(eventId); @@ -229,7 +233,8 @@ public abstract class Nerve { double re = 0.0; if (rzType != RZ.NOT_RZ) { if (rzType == RZ.L2) { - re = ArithUtil.mul(param, -w); + //re = ArithUtil.mul(param, -w); + re = param * -w; } else if (rzType == RZ.L1) { if (w > 0) { re = -param; @@ -248,11 +253,15 @@ public abstract class Nerve { int key = entry.getKey();//上层隐层神经元的编号 double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 - double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 + //double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 + double wp = bn * h; double regular = regularization(w, param);//正则化抑制权重s - w = ArithUtil.add(w, regular); - w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 - double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 + //w = ArithUtil.add(w, regular); + w = w + regular; + //w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 + w = w + wp; + // double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 + double dm = w * gradient; // System.out.println("allG==" + allG + ",dm==" + dm); wg.put(key, dm);//保存上一层权重与梯度的积 dendrites.put(key, w);//保存修正结果 @@ -287,9 +296,10 @@ public abstract class Nerve { double value = featuresList.get(i); double w = dendrites.get(i + 1); //System.out.println("w==" + w + ",value==" + value); - sigma = ArithUtil.add(ArithUtil.mul(w, value), sigma); + //sigma = ArithUtil.add(ArithUtil.mul(w, value), sigma); + sigma = w * value + sigma; } - return ArithUtil.sub(sigma, threshold); + return sigma - threshold;//ArithUtil.sub(sigma, threshold); } private void initPower(boolean init, boolean isDynamic) throws Exception {//初始化权重及阈值 diff --git a/src/main/java/org/wlld/nerveEntity/SoftMax.java b/src/main/java/org/wlld/nerveEntity/SoftMax.java index 4863bea..11140ee 100644 --- a/src/main/java/org/wlld/nerveEntity/SoftMax.java +++ b/src/main/java/org/wlld/nerveEntity/SoftMax.java @@ -51,7 +51,8 @@ public class SoftMax extends Nerve { private double outGradient() {//生成输出层神经元梯度变化 double g = outNub; if (E == 1) { - g = ArithUtil.sub(g, 1); + //g = ArithUtil.sub(g, 1); + g = g - 1; } return g; } @@ -63,8 +64,9 @@ public class SoftMax extends Nerve { double eSelf = Math.exp(self); for (int i = 0; i < featuresList.size(); i++) { double value = featuresList.get(i); - sigma = ArithUtil.add(Math.exp(value), sigma); + // sigma = ArithUtil.add(Math.exp(value), sigma); + sigma = Math.exp(value) + sigma; } - return ArithUtil.div(eSelf, sigma); + return eSelf / sigma;//ArithUtil.div(eSelf, sigma); } } diff --git a/src/main/java/org/wlld/param/Food.java b/src/main/java/org/wlld/param/Food.java index d873fed..9b8f133 100644 --- a/src/main/java/org/wlld/param/Food.java +++ b/src/main/java/org/wlld/param/Food.java @@ -1,5 +1,6 @@ package org.wlld.param; +import org.wlld.imageRecognition.segmentation.KNerveManger; import org.wlld.imageRecognition.segmentation.RgbRegression; import java.util.ArrayList; @@ -22,6 +23,24 @@ public class Food { private int regionSize = 5;//纹理区域大小 private int step = 1;//特征取样步长 private double dispersedTh = 0.3;//选区筛选离散阈值 + private int speciesNub = 24;//种类数 + private KNerveManger kNerveManger; + + public KNerveManger getkNerveManger() { + return kNerveManger; + } + + public void setkNerveManger(KNerveManger kNerveManger) { + this.kNerveManger = kNerveManger; + } + + public int getSpeciesNub() { + return speciesNub; + } + + public void setSpeciesNub(int speciesNub) { + this.speciesNub = speciesNub; + } public int getStep() { return step; diff --git a/src/test/java/coverTest/DataObservation.java b/src/test/java/coverTest/DataObservation.java new file mode 100644 index 0000000..16fd9e3 --- /dev/null +++ b/src/test/java/coverTest/DataObservation.java @@ -0,0 +1,101 @@ +package coverTest; + +import coverTest.regionCut.RegionCut; +import coverTest.regionCut.RegionFeature; +import org.wlld.MatrixTools.Matrix; +import org.wlld.imageRecognition.Convolution; +import org.wlld.imageRecognition.Picture; +import org.wlld.imageRecognition.ThreeChannelMatrix; +import org.wlld.imageRecognition.segmentation.RgbRegression; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description数据观察 + */ +public class DataObservation { + private static Convolution convolution = new Convolution(); + + public static void main(String[] args) throws Exception { + //372,330,右 最大值 147.44 + //377 ,330右 最大值 69.6 + int xp = 123; + int yp = 165;//290 + observation2("/Users/lidapeng/Desktop/test/testOne/a0.jpg", xp, yp); + } + + public static void observation2(String url, int xp, int yp) throws Exception { + Picture picture = new Picture(); + ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix(url); + ThreeChannelMatrix myThreeChannelMatrix = convolution.getRegionMatrix(threeChannelMatrix, xp, yp, 270, 274); + RegionFeature regionFeature = new RegionFeature(myThreeChannelMatrix, xp, yp); + regionFeature.start(); + } + + public static void observation(String url, int xp, int yp, int size) throws Exception { + Picture picture = new Picture(); + ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix(url); + ThreeChannelMatrix myThreeChannelMatrix = convolution.getRegionMatrix(threeChannelMatrix, xp, yp, size, size); + //右 + ThreeChannelMatrix threeChannelMatrix2 = convolution.getRegionMatrix(threeChannelMatrix, xp, yp + size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix2, "右"); + //左 + ThreeChannelMatrix threeChannelMatrix3 = convolution.getRegionMatrix(threeChannelMatrix, xp, yp - size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix3, "左"); + //上 + ThreeChannelMatrix threeChannelMatrix4 = convolution.getRegionMatrix(threeChannelMatrix, xp - size, yp, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix4, "上"); + //下 + ThreeChannelMatrix threeChannelMatrix5 = convolution.getRegionMatrix(threeChannelMatrix, xp + size, yp, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix5, "下"); + //左上 + ThreeChannelMatrix threeChannelMatrix6 = convolution.getRegionMatrix(threeChannelMatrix, xp - size, yp - size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix6, "左上"); + //左下 + ThreeChannelMatrix threeChannelMatrix7 = convolution.getRegionMatrix(threeChannelMatrix, xp + size, yp - size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix7, "左下"); + //右上 + ThreeChannelMatrix threeChannelMatrix8 = convolution.getRegionMatrix(threeChannelMatrix, xp - size, yp + size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix8, "右上"); + //右下 + ThreeChannelMatrix threeChannelMatrix9 = convolution.getRegionMatrix(threeChannelMatrix, xp + size, yp + size, size, size); + getDist(myThreeChannelMatrix, threeChannelMatrix9, "右下"); + //getDist("/Users/lidapeng/Desktop/test/testOne/a1.jpg", 468, 713, rgbRegression, "测"); + + } + + private static void getDist(ThreeChannelMatrix threeChannelMatrix1, ThreeChannelMatrix threeChannelMatrix2, String name) throws Exception { + Matrix matrixR1 = threeChannelMatrix1.getMatrixR(); + Matrix matrixG1 = threeChannelMatrix1.getMatrixG(); + Matrix matrixB1 = threeChannelMatrix1.getMatrixB(); + Matrix matrixR2 = threeChannelMatrix2.getMatrixR(); + Matrix matrixG2 = threeChannelMatrix2.getMatrixG(); + Matrix matrixB2 = threeChannelMatrix2.getMatrixB(); + int x = matrixR1.getX(); + int y = matrixR1.getY(); + int nub = x * y; + double sigmaR = 0; + double sigmaG = 0; + double sigmaB = 0; + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + double subR = Math.pow(matrixR1.getNumber(i, j) - matrixR2.getNumber(i, j), 2); + double subG = Math.pow(matrixG1.getNumber(i, j) - matrixG2.getNumber(i, j), 2); + double subB = Math.pow(matrixB1.getNumber(i, j) - matrixB2.getNumber(i, j), 2); + sigmaR = subR + sigmaR; + sigmaG = subG + sigmaG; + sigmaB = subB + sigmaB; + } + } + sigmaR = sigmaR / nub; + sigmaG = sigmaG / nub; + sigmaB = sigmaB / nub; + double sigma = sigmaR + sigmaG + sigmaB; + System.out.println(name + ":" + sigma); + } +} diff --git a/src/test/java/coverTest/FoodTest.java b/src/test/java/coverTest/FoodTest.java index d501c38..ef85a42 100644 --- a/src/test/java/coverTest/FoodTest.java +++ b/src/test/java/coverTest/FoodTest.java @@ -8,6 +8,7 @@ import org.wlld.config.Classifier; import org.wlld.config.RZ; import org.wlld.config.StudyPattern; import org.wlld.imageRecognition.*; +import org.wlld.imageRecognition.segmentation.KNerveManger; import org.wlld.imageRecognition.segmentation.RegionBody; import org.wlld.imageRecognition.segmentation.RegionMapping; import org.wlld.imageRecognition.segmentation.Specifications; @@ -76,10 +77,13 @@ public class FoodTest { food.setShrink(5);//缩紧像素 food.setTimes(1);//聚类数据增强 food.setRegionSize(5); + KNerveManger kNerveManger = new KNerveManger(9, 24, 2000); + food.setkNerveManger(kNerveManger); food.setRowMark(0.15);//0.12 food.setColumnMark(0.15);//0.25 food.setRegressionNub(20000); food.setTrayTh(0.08); + food.setDispersedTh(0.5); templeConfig.setClassifier(Classifier.KNN); templeConfig.init(StudyPattern.Cover_Pattern, true, 400, 400, 3); if (modelParameter != null) { @@ -99,10 +103,11 @@ public class FoodTest { specifications.setMaxWidth(600); specifications.setMaxHeight(600); specificationsList.add(specifications); + KNerveManger kNerveManger = templeConfig.getFood().getkNerveManger(); // ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg"); // operation.setTray(threeChannelMatrix); String name = "/Users/lidapeng/Desktop/test/testOne/"; - for (int i = 0; i < 5; i++) { + for (int i = 0; i < 1; i++) { System.out.println("轮数============================" + i); ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix(name + "a" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix(name + "b" + i + ".jpg"); @@ -128,32 +133,34 @@ public class FoodTest { ThreeChannelMatrix threeChannelMatrix22 = picture.getThreeMatrix(name + "v" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix23 = picture.getThreeMatrix(name + "w" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix24 = picture.getThreeMatrix(name + "x" + i + ".jpg"); - operation.colorStudy(threeChannelMatrix1, 1, specificationsList); - operation.colorStudy(threeChannelMatrix2, 2, specificationsList); - operation.colorStudy(threeChannelMatrix3, 3, specificationsList); - operation.colorStudy(threeChannelMatrix4, 4, specificationsList); - operation.colorStudy(threeChannelMatrix5, 5, specificationsList); - operation.colorStudy(threeChannelMatrix6, 6, specificationsList); - operation.colorStudy(threeChannelMatrix7, 7, specificationsList); - operation.colorStudy(threeChannelMatrix8, 8, specificationsList); - operation.colorStudy(threeChannelMatrix9, 9, specificationsList); - operation.colorStudy(threeChannelMatrix10, 10, specificationsList); - operation.colorStudy(threeChannelMatrix11, 11, specificationsList); - operation.colorStudy(threeChannelMatrix12, 12, specificationsList); - operation.colorStudy(threeChannelMatrix13, 13, specificationsList); - operation.colorStudy(threeChannelMatrix14, 14, specificationsList); - operation.colorStudy(threeChannelMatrix15, 15, specificationsList); - operation.colorStudy(threeChannelMatrix16, 16, specificationsList); - operation.colorStudy(threeChannelMatrix17, 17, specificationsList); - operation.colorStudy(threeChannelMatrix18, 18, specificationsList); - operation.colorStudy(threeChannelMatrix19, 19, specificationsList); - operation.colorStudy(threeChannelMatrix20, 20, specificationsList); - operation.colorStudy(threeChannelMatrix21, 21, specificationsList); - operation.colorStudy(threeChannelMatrix22, 22, specificationsList); - operation.colorStudy(threeChannelMatrix23, 23, specificationsList); - operation.colorStudy(threeChannelMatrix24, 24, specificationsList); + operation.colorStudy(threeChannelMatrix1, 1, specificationsList, name); + operation.colorStudy(threeChannelMatrix2, 2, specificationsList, name); + operation.colorStudy(threeChannelMatrix3, 3, specificationsList, name); + operation.colorStudy(threeChannelMatrix4, 4, specificationsList, name); + operation.colorStudy(threeChannelMatrix5, 5, specificationsList, name); + operation.colorStudy(threeChannelMatrix6, 6, specificationsList, name); + operation.colorStudy(threeChannelMatrix7, 7, specificationsList, name); + operation.colorStudy(threeChannelMatrix8, 8, specificationsList, name); + operation.colorStudy(threeChannelMatrix9, 9, specificationsList, name); + operation.colorStudy(threeChannelMatrix10, 10, specificationsList, name); + operation.colorStudy(threeChannelMatrix11, 11, specificationsList, name); + operation.colorStudy(threeChannelMatrix12, 12, specificationsList, name); + operation.colorStudy(threeChannelMatrix13, 13, specificationsList, name); + operation.colorStudy(threeChannelMatrix14, 14, specificationsList, name); + operation.colorStudy(threeChannelMatrix15, 15, specificationsList, name); + operation.colorStudy(threeChannelMatrix16, 16, specificationsList, name); + operation.colorStudy(threeChannelMatrix17, 17, specificationsList, name); + operation.colorStudy(threeChannelMatrix18, 18, specificationsList, name); + operation.colorStudy(threeChannelMatrix19, 19, specificationsList, name); + operation.colorStudy(threeChannelMatrix20, 20, specificationsList, name); + operation.colorStudy(threeChannelMatrix21, 21, specificationsList, name); + operation.colorStudy(threeChannelMatrix22, 22, specificationsList, name); + operation.colorStudy(threeChannelMatrix23, 23, specificationsList, name); + operation.colorStudy(threeChannelMatrix24, 24, specificationsList, name); } - int i = 4; + System.out.println("========================"); + kNerveManger.startStudy(); + int i = 0; ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix(name + "a" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix(name + "b" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix(name + "c" + i + ".jpg"); @@ -179,30 +186,30 @@ public class FoodTest { ThreeChannelMatrix threeChannelMatrix22 = picture.getThreeMatrix(name + "v" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix23 = picture.getThreeMatrix(name + "w" + i + ".jpg"); ThreeChannelMatrix threeChannelMatrix24 = picture.getThreeMatrix(name + "x" + i + ".jpg"); - test3(threeChannelMatrix1, operation, specificationsList); - test3(threeChannelMatrix2, operation, specificationsList); - test3(threeChannelMatrix3, operation, specificationsList); - test3(threeChannelMatrix4, operation, specificationsList); - test3(threeChannelMatrix5, operation, specificationsList); - test3(threeChannelMatrix6, operation, specificationsList); - test3(threeChannelMatrix7, operation, specificationsList); - test3(threeChannelMatrix8, operation, specificationsList); - test3(threeChannelMatrix9, operation, specificationsList); - test3(threeChannelMatrix10, operation, specificationsList); - test3(threeChannelMatrix11, operation, specificationsList); - test3(threeChannelMatrix12, operation, specificationsList); - test3(threeChannelMatrix13, operation, specificationsList); - test3(threeChannelMatrix14, operation, specificationsList); - test3(threeChannelMatrix15, operation, specificationsList); - test3(threeChannelMatrix16, operation, specificationsList); - test3(threeChannelMatrix17, operation, specificationsList); - test3(threeChannelMatrix18, operation, specificationsList); - test3(threeChannelMatrix19, operation, specificationsList); - test3(threeChannelMatrix20, operation, specificationsList); - test3(threeChannelMatrix21, operation, specificationsList); - test3(threeChannelMatrix22, operation, specificationsList); - test3(threeChannelMatrix23, operation, specificationsList); - test3(threeChannelMatrix24, operation, specificationsList); + operation.colorLook(threeChannelMatrix1, specificationsList); + operation.colorLook(threeChannelMatrix2, specificationsList); + operation.colorLook(threeChannelMatrix3, specificationsList); + operation.colorLook(threeChannelMatrix4, specificationsList); + operation.colorLook(threeChannelMatrix5, specificationsList); + operation.colorLook(threeChannelMatrix6, specificationsList); + operation.colorLook(threeChannelMatrix7, specificationsList); + operation.colorLook(threeChannelMatrix8, specificationsList); + operation.colorLook(threeChannelMatrix9, specificationsList); + operation.colorLook(threeChannelMatrix10, specificationsList); + operation.colorLook(threeChannelMatrix11, specificationsList); + operation.colorLook(threeChannelMatrix12, specificationsList); + operation.colorLook(threeChannelMatrix13, specificationsList); + operation.colorLook(threeChannelMatrix14, specificationsList); + operation.colorLook(threeChannelMatrix15, specificationsList); + operation.colorLook(threeChannelMatrix16, specificationsList); + operation.colorLook(threeChannelMatrix17, specificationsList); + operation.colorLook(threeChannelMatrix18, specificationsList); + operation.colorLook(threeChannelMatrix19, specificationsList); + operation.colorLook(threeChannelMatrix20, specificationsList); + operation.colorLook(threeChannelMatrix21, specificationsList); + operation.colorLook(threeChannelMatrix22, specificationsList); + operation.colorLook(threeChannelMatrix23, specificationsList); + operation.colorLook(threeChannelMatrix24, specificationsList); } diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index 03c613e..b306aaa 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -4,7 +4,6 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.config.Classifier; import org.wlld.config.StudyPattern; import org.wlld.imageRecognition.*; -import org.wlld.imageRecognition.segmentation.FindMaxSimilar; import org.wlld.imageRecognition.segmentation.RegionBody; import org.wlld.imageRecognition.segmentation.Specifications; import org.wlld.imageRecognition.segmentation.Watershed; @@ -55,9 +54,7 @@ public class ForestTest { int xSize = maxX - minX; int ySize = maxY - minY; ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize); - List feature = convolution.getCenterTexture(threeChannelMatrix1, templeConfig.getFood().getRegionSize(), - templeConfig.getPoolSize(), templeConfig, templeConfig.getFeatureNub()); - System.out.println(feature); + } else { System.out.println("size===" + regionBodies.size()); } diff --git a/src/test/java/coverTest/PicTest.java b/src/test/java/coverTest/PicTest.java index cfc6fc8..3e5deed 100644 --- a/src/test/java/coverTest/PicTest.java +++ b/src/test/java/coverTest/PicTest.java @@ -36,8 +36,8 @@ public class PicTest { //testImage(right, wrong, a, b); //test(); - } + } public static void test() throws Exception {//对图像进行识别测试 Picture picture = new Picture(); diff --git a/src/test/java/coverTest/RGBBody.java b/src/test/java/coverTest/RGBBody.java new file mode 100644 index 0000000..337f872 --- /dev/null +++ b/src/test/java/coverTest/RGBBody.java @@ -0,0 +1,46 @@ +package coverTest; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class RGBBody { + private double r; + private double g; + private double b; + private double rgb; + + public double getR() { + return r; + } + + public void setR(double r) { + this.r = r; + } + + public double getG() { + return g; + } + + public void setG(double g) { + this.g = g; + } + + public double getB() { + return b; + } + + public void setB(double b) { + this.b = b; + } + + public double getRgb() { + return rgb; + } + + public void setRgb(double rgb) { + this.rgb = rgb; + } +} diff --git a/src/test/java/coverTest/regionCut/RegionCut.java b/src/test/java/coverTest/regionCut/RegionCut.java new file mode 100644 index 0000000..63b0dc4 --- /dev/null +++ b/src/test/java/coverTest/regionCut/RegionCut.java @@ -0,0 +1,284 @@ +package coverTest.regionCut; + +import org.wlld.Ma; +import org.wlld.MatrixTools.Matrix; +import org.wlld.config.Kernel; +import org.wlld.imageRecognition.ThreeChannelMatrix; + +import java.util.HashMap; +import java.util.Map; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 分区切割 + */ +public class RegionCut { + private Matrix matrixH; + private Matrix regionMatrix;//分区地图 + private int fatherX; + private int fatherY; + private int size; + private int id = 1;//分区id + private Map minMap = new HashMap<>();//保存最小值 + private Map maxMap = new HashMap<>();//保存最大值 + + public RegionCut(Matrix matrixH, int fatherX, int fatherY, int size) { + this.matrixH = matrixH; + this.fatherX = fatherX; + this.fatherY = fatherY; + this.size = size; + regionMatrix = new Matrix(matrixH.getX(), matrixH.getY()); + } + + private void setLimit(int id, double pixel) { + double min = minMap.get(id); + double max = maxMap.get(id); + if (pixel > max) { + maxMap.put(id, pixel); + } + if (pixel < min) { + minMap.put(id, pixel); + } + } + + private void firstCut() throws Exception {//进行第一次切割 + int x = matrixH.getX(); + int y = matrixH.getY(); + int size = x * y; + System.out.println("像素数量:" + size); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + double regionId = regionMatrix.getNumber(i, j); + if (regionId < 0.5) {//该像素没有被连接 + boolean isStop; + double self = matrixH.getNumber(i, j);//灰度值 + regionMatrix.setNub(i, j, id); + minMap.put(id, self); + maxMap.put(id, self); + //System.out.println(regionMatrix.getString()); + int xi = i; + int yj = j; + do { + double mySelf = matrixH.getNumber(xi, yj);//灰度值 + int pixel = pixelLine(xi, yj, mySelf); + int column = pixel & 0xfff; + int row = (pixel >> 12) & 0xfff; + double type = regionMatrix.getNumber(row, column); + if (type < 0.5) {//可以连接 + regionMatrix.setNub(row, column, id);//进行连接 + setLimit(id, mySelf); + double mySelfSon = matrixH.getNumber(row, column);//灰度值 + int pixelOther = pixelLine(row, column, mySelfSon); + int column2 = pixelOther & 0xfff; + int row2 = (pixelOther >> 12) & 0xfff; + isStop = row2 == xi && column2 == yj; + xi = row; + yj = column; + } else {//已经被连接了,跳出 + isStop = true; + } + } while (!isStop); + id++; + } + } + } + System.out.println("第一次选区数量:" + id); + } + + public void secondCut() throws Exception {//二切 + int x = matrixH.getX(); + int y = matrixH.getY(); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + double key = regionMatrix.getNumber(i, j);//与周围八方向比较看是否有异类 + getOther(i, j, (int) key, matrixH.getNumber(i, j)); + } + } + } + + private void updateType(int type, int toType) throws Exception { + int x = regionMatrix.getX(); + int y = regionMatrix.getY(); + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + if (regionMatrix.getNumber(i, j) == type) { + regionMatrix.setNub(i, j, toType); + } + } + } + } + + private void getOther(int x, int y, int key, double self) throws Exception { + double[] pixels = getPixels(x, y, false); + for (int i = 0; i < pixels.length; i++) { + int pix = (int) pixels[i]; + if (pix > 0 && pix != key) {//接壤的非同类 + double min = minMap.get(pix); + double max = maxMap.get(pix); + double maxDist = max - min; + int row = x; + int column = y; + switch (i) { + case 0://上 + row = x - 1; + break; + case 1://左 + column = y - 1; + break; + case 2://下 + row = x + 1; + break; + case 3://右 + column = y + 1; + break; + case 4://左上 + column = y - 1; + row = x - 1; + break; + case 5://左下 + column = y - 1; + row = x + 1; + break; + case 6://右下 + column = y + 1; + row = x + 1; + break; + case 7://右上 + column = y + 1; + row = x - 1; + break; + } + double dist = Math.abs(matrixH.getNumber(row, column) - self); + if (dist < maxDist * 0.2) {//两个选区可以合并 + setLimit(key, min); + setLimit(key, max); + updateType(pix, key); + id--; + } + break; + } + } + } + + public void start() throws Exception { + firstCut();//初切 + System.out.println("区域数量1:" + id); + for (int i = 0; i < 1; i++) { + secondCut();//二切 + } + System.out.println("区域数量2:" + id); + System.out.println(regionMatrix.getString()); + } + + private int pixelLine(int x, int y, double self) throws Exception { + double[] pixels = getPixels(x, y, true); + int minIndex = getMinIndex(pixels, self); + int row = x; + int column = y; + switch (minIndex) { + case 0://上 + row = x - 1; + break; + case 1://左 + column = y - 1; + break; + case 2://下 + row = x + 1; + break; + case 3://右 + column = y + 1; + break; + case 4://左上 + column = y - 1; + row = x - 1; + break; + case 5://左下 + column = y - 1; + row = x + 1; + break; + case 6://右下 + column = y + 1; + row = x + 1; + break; + case 7://右上 + column = y + 1; + row = x - 1; + break; + } + return row << 12 | column; + } + + private double[] getPixels(int x, int y, boolean isFirst) throws Exception { + double left = 1, leftTop = 1, leftBottom = 1, right = 1, rightTop = 1, rightBottom = 1, top = 1, bottom = 1; + Matrix matrix; + if (isFirst) { + matrix = matrixH; + } else { + matrix = regionMatrix; + } + if (x == 0) { + top = -1; + leftTop = -1; + rightTop = -1; + } + if (y == 0) { + leftTop = -1; + left = -1; + leftBottom = -1; + } + if (x == size - 1) { + leftBottom = -1; + bottom = -1; + rightBottom = -1; + } + if (y == size - 1) { + rightTop = -1; + right = -1; + rightBottom = -1; + } + if (top > 0) { + top = matrix.getNumber(x - 1, y); + } + if (left > 0) { + left = matrix.getNumber(x, y - 1); + } + if (right > 0) { + right = matrix.getNumber(x, y + 1); + } + if (bottom > 0) { + bottom = matrix.getNumber(x + 1, y); + } + if (leftTop > 0) { + leftTop = matrix.getNumber(x - 1, y - 1); + } + if (leftBottom > 0) { + leftBottom = matrix.getNumber(x + 1, y - 1); + } + if (rightTop > 0) { + rightTop = matrix.getNumber(x - 1, y + 1); + } + if (rightBottom > 0) { + rightBottom = matrix.getNumber(x + 1, y + 1); + } + return new double[]{top, left, bottom, right, leftTop, leftBottom, rightBottom, rightTop}; + + } + + private int getMinIndex(double[] array, double self) {//获取最小值 + double min = -1; + int minIdx = 0; + for (int i = 0; i < array.length; i++) { + double nub = array[i]; + if (nub > 0) { + nub = Math.abs(nub - self); + if (min < 0 || nub < min) { + min = nub; + minIdx = i; + } + } + } + return minIdx; + } +} diff --git a/src/test/java/coverTest/regionCut/RegionCutBody.java b/src/test/java/coverTest/regionCut/RegionCutBody.java new file mode 100644 index 0000000..5e0d451 --- /dev/null +++ b/src/test/java/coverTest/regionCut/RegionCutBody.java @@ -0,0 +1,15 @@ +package coverTest.regionCut; + +import java.util.ArrayList; +import java.util.List; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class RegionCutBody { + private List pixels = new ArrayList<>(); + +} diff --git a/src/test/java/coverTest/regionCut/RegionFeature.java b/src/test/java/coverTest/regionCut/RegionFeature.java new file mode 100644 index 0000000..068077a --- /dev/null +++ b/src/test/java/coverTest/regionCut/RegionFeature.java @@ -0,0 +1,84 @@ +package coverTest.regionCut; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.config.Kernel; +import org.wlld.imageRecognition.ThreeChannelMatrix; +import org.wlld.imageRecognition.segmentation.RgbRegression; +import org.wlld.tools.Frequency; + +import java.util.Arrays; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 利用边缘算子提取特征 + */ +public class RegionFeature extends Frequency { + private Matrix matrix; + private Matrix matrixR; + private Matrix matrixG; + private Matrix matrixB; + private Matrix kernel = Kernel.Big; + private int fatherX; + private int fatherY; + private RgbRegression rgbRegression = new RgbRegression(25); + + public RegionFeature(ThreeChannelMatrix threeChannelMatrix, int fatherX, int fatherY) { + this.matrix = threeChannelMatrix.getMatrixRGB(); + this.fatherX = fatherX; + this.fatherY = fatherY; + matrixR = threeChannelMatrix.getMatrixR(); + matrixG = threeChannelMatrix.getMatrixG(); + matrixB = threeChannelMatrix.getMatrixB(); + } + + private double getMatrixVar(Matrix matrix) throws Exception { + int x = matrix.getX(); + int y = matrix.getY(); + double[] data = new double[x * y]; + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + int index = i * y + j; + data[index] = matrix.getNumber(i, j); + } + } + //System.out.println(Arrays.toString(data)); + return variance(data); + } + + public void start() throws Exception { + int x = matrix.getX(); + int y = matrix.getY(); + int size = kernel.getX(); + for (int i = 0; i <= x - size; i += 2) { + int xSize = fatherX + i; + System.out.println("==================================" + xSize); + for (int j = 0; j <= y - size; j += 2) { + // double conNub = Math.abs(convolution(i, j));//卷积值 + double var = getMatrixVar(matrix.getSonOfMatrix(i, j, size, size)); + int ySize = fatherY + j; + System.out.println("x:" + xSize + ",y:" + ySize + ",var:" + var); + + + } + } + } + + private double convolution(int x, int y) throws Exception {//计算卷积 + double allNub = 0; + int xr; + int yr; + int kxMax = kernel.getX(); + int kyMax = kernel.getY(); + for (int i = 0; i < kxMax; i++) { + xr = i + x; + for (int j = 0; j < kyMax; j++) { + yr = j + y; + allNub = matrix.getNumber(xr, yr) * kernel.getNumber(i, j) + allNub; + } + } + return allNub; + } +} diff --git a/src/test/java/org/wlld/NerveDemo1.java b/src/test/java/org/wlld/NerveDemo1.java index 76806f2..9b2752c 100644 --- a/src/test/java/org/wlld/NerveDemo1.java +++ b/src/test/java/org/wlld/NerveDemo1.java @@ -36,13 +36,13 @@ public class NerveDemo1 { * @param isDynamic 是否是动态神经元 */ NerveManager nerveManager = new NerveManager(2, 6, 1, 2, new Tanh(), - false, true, 0, RZ.NOT_RZ, 0); + false, false, 0, RZ.NOT_RZ, 0); nerveManager.init(true, false, false, false); //创建训练 List> list_right = new LinkedList<>();//存放正确的值 List> list_wrong = new LinkedList<>();//存放错误的值 Random random = new Random(); - for (int i = 0; i < 1000; i++) { + for (int i = 0; i < 3000; i++) { Map mp1 = new HashMap<>(); Map mp2 = new HashMap<>(); mp1.put(0, random.nextDouble());