diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 189aabb..601ead1 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -26,9 +26,9 @@ import java.util.Map; public class TempleConfig { private NerveManager nerveManager;//神经网络管理器 private NerveManager convolutionNerveManager;//卷积神经网络管理器 - private NerveManager convolutionNerveManagerR;//卷积神经网络管理器 - private NerveManager convolutionNerveManagerG;//卷积神经网络管理器 - private NerveManager convolutionNerveManagerB;//卷积神经网络管理器 + private NerveManager convolutionNerveManagerR;//R卷积神经网络管理器 + private NerveManager convolutionNerveManagerG;//G卷积神经网络管理器 + private NerveManager convolutionNerveManagerB;//B卷积神经网络管理器 private boolean isAccurate = false;//是否保留精度 private int row = 5;//行的最小比例 private int column = 3;//列的最小比例 @@ -48,6 +48,7 @@ public class TempleConfig { private Normalization normalization = new Normalization();//统一归一化 private double avg = 0;//覆盖均值 private int sensoryNerveNub;//输入神经元个数 + private boolean isShowLog = false; public boolean isAccurate() { return isAccurate; @@ -117,6 +118,27 @@ public class TempleConfig { } + public void isShowLog(boolean isShowLog) {//是否打印学习数据 + this.isShowLog = isShowLog; + } + + public void startLvq() throws Exception { + switch (classifier) { + case Classifier.LVQ: + lvq.start(); + break; + case Classifier.VAvg: + vectorK.study(); + break; + } + if (isHavePosition) { + for (Map.Entry entry : kClusteringMap.entrySet()) { + entry.getValue().start(); + } + boxReady = true; + } + } + private Map kClusteringMap = new HashMap<>(); public Map getKClusteringMap() { @@ -223,7 +245,7 @@ public class TempleConfig { , int deep) throws Exception { nerveManager = new NerveManager(sensoryNerveNub, 9, classificationNub, deep, new Sigmod(), false, isAccurate); - nerveManager.init(initPower, false); + nerveManager.init(initPower, false, isShowLog); } private void initConvolutionVision(boolean initPower, int width, int height) throws Exception {//精准模式 @@ -277,7 +299,7 @@ public class TempleConfig { NerveManager convolutionNerveManager = new NerveManager(1, 1, 1, deep - 1, new ReLu(), true, isAccurate); convolutionNerveManager.setMatrixMap(matrixMap);//给卷积网络管理器注入期望矩阵 - convolutionNerveManager.init(initPower, true); + convolutionNerveManager.init(initPower, true, isShowLog); return convolutionNerveManager; } diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java index 0289c44..4c6bbc4 100644 --- a/src/main/java/org/wlld/nerveCenter/NerveManager.java +++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java @@ -30,7 +30,7 @@ public class NerveManager { private Map matrixMap = new HashMap<>();//主键与期望矩阵的映射 private boolean isDynamic;//是否是动态神经网络 private List studyList = new ArrayList<>(); - private boolean isAccurate = false;//是否保留精度 + private boolean isAccurate;//是否保留精度 public List getStudyList() {//查看每一次的学习率 return studyList; @@ -259,7 +259,7 @@ public class NerveManager { * @param isMatrix 参数是否是一个矩阵 * @throws Exception */ - public void init(boolean initPower, boolean isMatrix) throws Exception {//进行神经网络的初始化构建 + public void init(boolean initPower, boolean isMatrix, boolean isShowLog) throws Exception {//进行神经网络的初始化构建 this.initPower = initPower; initDepthNerve(isMatrix);//初始化深度隐层神经元 List nerveList = depthNerves.get(0);//第一层隐层神经元 @@ -267,7 +267,8 @@ public class NerveManager { List lastNeveList = depthNerves.get(depthNerves.size() - 1); //初始化输出神经元 for (int i = 1; i < outNerveNub + 1; i++) { - OutNerve outNerve = new OutNerve(i, hiddenNerveNub, 0, studyPoint, initPower, activeFunction, isMatrix, isAccurate); + OutNerve outNerve = new OutNerve(i, hiddenNerveNub, 0, studyPoint, initPower, + activeFunction, isMatrix, isAccurate, isShowLog); if (isMatrix) {//是卷积层神经网络 outNerve.setMatrixMap(matrixMap); } diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index bfb673d..0bafaec 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -19,10 +19,13 @@ import java.util.Map; */ public class OutNerve extends Nerve { private Map matrixMapE;//主键与期望矩阵的映射 + private boolean isShowLog; public OutNerve(int id, int upNub, int downNub, double studyPoint, boolean init, - ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate) throws Exception { + ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate + , boolean isShowLog) throws Exception { super(id, upNub, "OutNerve", downNub, studyPoint, init, activeFunction, isDynamic, isAccurate); + this.isShowLog = isShowLog; } @@ -45,7 +48,9 @@ public class OutNerve extends Nerve { } else { this.E = 0; } - System.out.println("E===" + this.E + ",out==" + out + ",nerveId==" + getId()); + if (isShowLog) { + System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId()); + } gradient = outGradient();//当前梯度变化 //调整权重 修改阈值 并进行反向传播 updatePower(eventId); @@ -66,8 +71,10 @@ public class OutNerve extends Nerve { Matrix myMatrix = dynamicNerve(matrix, eventId, isKernelStudy); if (isKernelStudy) {//回传 Matrix matrix1 = matrixMapE.get(E); - System.out.println("E================" + E); - System.out.println(myMatrix.getString()); + if (isShowLog) { + System.out.println("E================" + E); + System.out.println(myMatrix.getString()); + } if (matrix1.getX() <= myMatrix.getX() && matrix1.getY() <= myMatrix.getY()) { double g = getGradient(myMatrix, matrix1); backMatrix(g, eventId); diff --git a/src/test/java/org/wlld/App.java b/src/test/java/org/wlld/App.java index 98f15b7..ec718e9 100644 --- a/src/test/java/org/wlld/App.java +++ b/src/test/java/org/wlld/App.java @@ -27,7 +27,7 @@ public class App { public static void test3() throws Exception { NerveManager nerveManager = new NerveManager(3, 6, 3 , 3, new Sigmod(), false, true); - nerveManager.init(true, false);//初始化 + nerveManager.init(true, false, false);//初始化 List> data = new ArrayList<>();//正样本 List> dataB = new ArrayList<>();//负样本 List> dataC = new ArrayList<>();//负样本 diff --git a/src/test/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java index 860a6d1..9a8cb4c 100644 --- a/src/test/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -125,8 +125,9 @@ public class HelloWorld { // frame.setLengthWidth(640); // templeConfig.setFrame(frame); templeConfig.setClassifier(Classifier.DNN); + //templeConfig.isShowLog(true); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 2); -// ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA, ModelParameter.class); +// ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA2, ModelParameter.class); // templeConfig.insertModel(modelParameter2); Operation operation = new Operation(templeConfig); //a b c d 物品 e是背景 @@ -135,9 +136,9 @@ public class HelloWorld { for (int i = 1; i < 1900; i++) {//一阶段 System.out.println("study1===================" + i); //读取本地URL地址图片,并转化成矩阵 - Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); + Matrix a = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/picture/a" + i + ".jpg"); //Matrix b = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/picture/b" + i + ".jpg"); - Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg"); + Matrix c = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/picture/c" + i + ".jpg"); //Matrix d = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/picture/d" + i + ".jpg"); //Matrix f = picture.getImageMatrixByLocal("D:\\share\\picture/f" + i + ".png"); //将图像矩阵和标注加入进行学习,Accuracy_Pattern 模式 进行第二次学习 @@ -149,8 +150,7 @@ public class HelloWorld { //operation.learning(d, 4, false); } } -// ModelParameter modelParameter = JSON.parseObject(ModelData.DATA8, ModelParameter.class); -// templeConfig.insertModel(modelParameter); + //二阶段 for (int i = 1; i < 1900; i++) { System.out.println("avg==" + i); @@ -194,8 +194,8 @@ public class HelloWorld { // TempleConfig templeConfig2 = new TempleConfig(false); // templeConfig2.init(StudyPattern.Accuracy_Pattern, true, 1000, 1000, 2); // templeConfig2.insertModel(modelParameter2); -// -// Operation operation2 = new Operation(templeConfig2); + + // Operation operation2 = new Operation(templeConfig2); int wrong = 0; int allNub = 0; for (int i = 1900; i <= 2000; i++) { @@ -211,6 +211,7 @@ public class HelloWorld { int cn = operation.toSee(c); if (an != 1) { wrong++; + } else { } if (cn != 2) { wrong++; diff --git a/src/test/java/org/wlld/NerveDemo1.java b/src/test/java/org/wlld/NerveDemo1.java index 2d9e368..5761919 100644 --- a/src/test/java/org/wlld/NerveDemo1.java +++ b/src/test/java/org/wlld/NerveDemo1.java @@ -33,7 +33,7 @@ public class NerveDemo1 { * @param isDynamic 是否是动态神经元 */ NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), false, true); - nerveManager.init(true, false); + nerveManager.init(true, false, false); //创建训练