diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index af2e0d9..07ff605 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -296,7 +296,9 @@ public class Operation {//进行计算 private void lvq(int tagging, Matrix myMatrix) throws Exception {//LVQ学习 LVQ lvq = templeConfig.getLvq(); Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); - System.out.println(vector.getString()); + if (templeConfig.isShowLog()) { + System.out.println(vector.getString()); + } MatrixBody matrixBody = new MatrixBody(); matrixBody.setMatrix(vector); matrixBody.setId(tagging); diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 848923e..f9e0273 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -262,7 +262,7 @@ public class TempleConfig { private void initNerveManager(boolean initPower, int sensoryNerveNub , int deep, double studyPoint) throws Exception { - nerveManager = new NerveManager(sensoryNerveNub, 9, + nerveManager = new NerveManager(sensoryNerveNub, 6, classificationNub, deep, activeFunction, false, isAccurate, studyPoint); nerveManager.init(initPower, false, isShowLog); } diff --git a/src/main/java/org/wlld/imageRecognition/border/LVQ.java b/src/main/java/org/wlld/imageRecognition/border/LVQ.java index 9455891..24af36f 100644 --- a/src/main/java/org/wlld/imageRecognition/border/LVQ.java +++ b/src/main/java/org/wlld/imageRecognition/border/LVQ.java @@ -16,7 +16,7 @@ public class LVQ { private int typeNub;//原型聚类个数,即分类个数(需要模型返回) private MatrixBody[] model;//原型向量(需要模型返回) private List matrixList = new ArrayList<>(); - private double studyPoint = 0.1;//量化学习率 + private double studyPoint = 0.0001;//量化学习率 private int length;//向量长度(需要返回) private boolean isReady = false; private int lvqNub; @@ -86,8 +86,6 @@ public class LVQ { long type = matrixBody.getId();//类别 double distEnd = 0; int id = 0; - double dis0 = 0; - double dis1 = 1; for (int i = 0; i < typeNub; i++) { MatrixBody modelBody = model[i]; Matrix modelMatrix = modelBody.getMatrix(); @@ -97,16 +95,10 @@ public class LVQ { id = modelBody.getId(); distEnd = dist; } - if (i == 0) { - dis0 = dist; - } else { - dis1 = dist; - } } MatrixBody modelBody = model[id]; Matrix modelMatrix = modelBody.getMatrix(); boolean isRight = id == type; - System.out.println("type==" + type + ",dis0==" + dis0 + ",dis1==" + dis1); Matrix matrix1 = op(matrix, modelMatrix, isRight); modelBody.setMatrix(matrix1); } diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index f9b9924..644ffd2 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -223,11 +223,13 @@ public abstract class Nerve { private void updateW(double h, long eventId) {//h是学习率 * 当前g(梯度) List list = features.get(eventId); + double stop = ArithUtil.sub(1, ArithUtil.div(ArithUtil.mul(studyPoint, 0.015), dendrites.size())); for (Map.Entry entry : dendrites.entrySet()) { int key = entry.getKey();//上层隐层神经元的编号 double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 + w = ArithUtil.mul(w, stop); w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 // System.out.println("allG==" + allG + ",dm==" + dm); diff --git a/src/test/java/coverTest/FoodTest.java b/src/test/java/coverTest/FoodTest.java index 4ffa518..bfec293 100644 --- a/src/test/java/coverTest/FoodTest.java +++ b/src/test/java/coverTest/FoodTest.java @@ -21,7 +21,7 @@ public class FoodTest { public static void food() throws Exception { Picture picture = new Picture(); - TempleConfig templeConfig = new TempleConfig(false, true); + TempleConfig templeConfig = new TempleConfig(false, false); templeConfig.setClassifier(Classifier.DNN); templeConfig.isShowLog(true); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4); @@ -60,7 +60,7 @@ public class FoodTest { // } // templeConfig.getNormalization().avg(); for (int j = 0; j < 1; j++) { - for (int i = 1; i < 1900; i++) { + for (int i = 1; i < 1500; i++) { System.out.println("j==" + j + ",study2==================" + i); //读取本地URL地址图片,并转化成矩阵 Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); @@ -88,7 +88,7 @@ public class FoodTest { // Operation operation2 = new Operation(templeConfig2); int wrong = 0; int allNub = 0; - for (int i = 1900; i <= 1998; i++) { + for (int i = 1500; i <= 1600; i++) { //读取本地URL地址图片,并转化成矩阵 Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg"); diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java index 26eef72..3c916c7 100644 --- a/src/test/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -79,7 +79,7 @@ public class HelloWorld { } templeConfig.getNormalization().avg(); //三阶段学习 - for (int i = 1; i < 1900; i++) { + for (int i = 1; i < 1000; i++) { System.out.println("study2==================" + i); //读取本地URL地址图片,并转化成矩阵 Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");