From 2501596973d155ee3fda876f603aacb06a641682 Mon Sep 17 00:00:00 2001 From: lidapeng Date: Sat, 18 Jan 2020 11:12:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=A8=A1=E5=9E=8B=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=8E=B7=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 41 ---------- src/main/java/org/wlld/HelloWorld.java | 4 +- .../wlld/imageRecognition/TempleConfig.java | 11 ++- .../org/wlld/nerveCenter/NerveManager.java | 80 ++++++++++++++++++- .../org/wlld/nerveEntity/DymNerveStudy.java | 30 +++++++ .../org/wlld/nerveEntity/ModelParameter.java | 18 +++++ src/main/java/org/wlld/nerveEntity/Nerve.java | 8 ++ .../java/org/wlld/nerveEntity/OutNerve.java | 2 +- 8 files changed, 144 insertions(+), 50 deletions(-) create mode 100644 src/main/java/org/wlld/nerveEntity/DymNerveStudy.java diff --git a/pom.xml b/pom.xml index dc88620..712ab8a 100644 --- a/pom.xml +++ b/pom.xml @@ -11,22 +11,6 @@ myBrain http://www.example.com - - - - The Apache Software License, Version 2.0 - http://www.apache.org/licenses/LICENSE-2.0.txt - repo - - - - - thenk008 - 794757862@qq.com - hope-redheart - https://www.cnblogs.com/yjp372928571 - - UTF-8 1.8 @@ -48,32 +32,7 @@ - - org.apache.maven.plugins - maven-gpg-plugin - 1.6 - - - verify - - sign - - - - - - - releases - Nexus Release Repository - https://oss.sonatype.org/service/local/staging/deploy/maven2 - - - snapshots - Nexus Snapshot Repository - https://oss.sonatype.org/content/repositories/snapshots - - diff --git a/src/main/java/org/wlld/HelloWorld.java b/src/main/java/org/wlld/HelloWorld.java index 704aa56..6a022d0 100644 --- a/src/main/java/org/wlld/HelloWorld.java +++ b/src/main/java/org/wlld/HelloWorld.java @@ -29,7 +29,7 @@ public class HelloWorld { Map wrongTagging = new HashMap<>();//分类标注 rightTagging.put(1, 1.0); wrongTagging.put(1, 0.0); - for (int i = 1; i < 500; i++) { + for (int i = 1; i < 5; i++) { System.out.println("开始学习1==" + i); //读取本地URL地址图片,并转化成矩阵 Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); @@ -38,7 +38,7 @@ public class HelloWorld { operation.learning(right, rightTagging, false); operation.learning(wrong, wrongTagging, false); } - for (int i = 1; i < 500; i++) {//神经网络学习 + for (int i = 1; i < 5; i++) {//神经网络学习 System.out.println("开始学习2==" + i); //读取本地URL地址图片,并转化成矩阵 Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png"); diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index ab5b1d0..a44a69c 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -1,7 +1,6 @@ package org.wlld.imageRecognition; import org.wlld.MatrixTools.Matrix; -import org.wlld.MatrixTools.MatrixOperation; import org.wlld.config.StudyPattern; import org.wlld.function.ReLu; import org.wlld.function.Sigmod; @@ -99,8 +98,14 @@ public class TempleConfig { convolutionNerveManager.init(initPower, true, nerveManager); } - public ModelParameter getModel() {//获取模型参数 - return nerveManager.getModelParameter(); + public ModelParameter getModel() throws Exception {//获取模型参数 + ModelParameter modelParameter = nerveManager.getModelParameter(); + if (studyPattern == StudyPattern.Accuracy_Pattern) { + ModelParameter modelParameter1 = convolutionNerveManager.getModelParameter(); + modelParameter.setDymNerveStudies(modelParameter1.getDymNerveStudies()); + modelParameter.setDymOutNerveStudy(modelParameter1.getDymOutNerveStudy()); + } + return modelParameter; } public List getSensoryNerves() {//获取感知神经元 diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java index d4b4fb8..52726c5 100644 --- a/src/main/java/org/wlld/nerveCenter/NerveManager.java +++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java @@ -4,6 +4,7 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.i.ActiveFunction; import org.wlld.i.OutBack; import org.wlld.nerveEntity.*; +import org.wlld.tools.ArithUtil; import java.util.ArrayList; import java.util.HashMap; @@ -48,7 +49,46 @@ public class NerveManager { } } - public ModelParameter getModelParameter() {//获取当前模型参数 + private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数 + ModelParameter modelParameter = new ModelParameter(); + List dymNerveStudies = new ArrayList<>();//动态神经元隐层 + DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层 + modelParameter.setDymNerveStudies(dymNerveStudies); + modelParameter.setDymOutNerveStudy(dymOutNerveStudy); + for (int i = 0; i < depthNerves.size(); i++) { + Nerve depthNerve = depthNerves.get(i).get(0);//隐层神经元 + DymNerveStudy deepNerveStudy = new DymNerveStudy();//动态神经元输出层 + List list = deepNerveStudy.getList(); + deepNerveStudy.setThreshold(depthNerve.getThreshold());//获取偏移值 + Matrix matrix = depthNerve.getNerveMatrix(); + insertWList(matrix, list); + dymNerveStudies.add(deepNerveStudy); + } + Nerve outNerve = outNevers.get(0); + Matrix matrix = outNerve.getNerveMatrix(); + dymOutNerveStudy.setThreshold(outNerve.getThreshold()); + List list = dymOutNerveStudy.getList(); + insertWList(matrix, list); + return modelParameter; + } + + private void insertWList(Matrix matrix, List list) throws Exception {// + for (int i = 0; i < matrix.getX(); i++) { + for (int j = 0; j < matrix.getY(); j++) { + list.add(matrix.getNumber(i, j)); + } + } + } + + public ModelParameter getModelParameter() throws Exception { + if (isDynamic) { + return getDymModelParameter(); + } else { + return getStaticModelParameter(); + } + } + + private ModelParameter getStaticModelParameter() {//获取当前模型参数 ModelParameter modelParameter = new ModelParameter(); List> studyDepthNerves = new ArrayList<>();//隐层神经元模型 List outStudyNevers = new ArrayList<>();//输出神经元 @@ -78,8 +118,42 @@ public class NerveManager { return modelParameter; } - //注入模型参数 - public void insertModelParameter(ModelParameter modelParameter) { + public void insertModelParameter(ModelParameter modelParameter) throws Exception { + insertBpModelParameter(modelParameter);//全连接层注入参数 + if (isDynamic) { + insertConvolutionModelParameter(modelParameter); + } + } + + //注入卷积层模型参数 + private void insertConvolutionModelParameter(ModelParameter modelParameter) throws Exception { + List dymNerveStudyList = modelParameter.getDymNerveStudies(); + DymNerveStudy dymOutNerveStudy = modelParameter.getDymOutNerveStudy(); + for (int i = 0; i < depthNerves.size(); i++) { + Nerve depthNerve = depthNerves.get(i).get(0); + DymNerveStudy dymNerveStudy = dymNerveStudyList.get(i); + List list = dymNerveStudy.getList(); + Matrix nerveMatrix = depthNerve.getNerveMatrix(); + depthNerve.setThreshold(dymNerveStudy.getThreshold());//注入偏置项 + insertMatrix(nerveMatrix, list); + } + Nerve outNerve = outNevers.get(0); + outNerve.setThreshold(dymOutNerveStudy.getThreshold());//输出神经元注入偏置项 + Matrix outNervMatrix = outNerve.getNerveMatrix(); + List list = dymOutNerveStudy.getList(); + insertMatrix(outNervMatrix, list); + } + + private void insertMatrix(Matrix matrix, List list) throws Exception { + for (int i = 0; i < list.size(); i++) { + int x = i / 3; + int y = i % 3; + matrix.setNub(x, y, list.get(i)); + } + } + + //注入全连接模型参数 + private void insertBpModelParameter(ModelParameter modelParameter) { List> depthStudyNerves = modelParameter.getDepthNerves();//隐层神经元 List outStudyNevers = modelParameter.getOutNevers();//输出神经元 //隐层神经元参数注入 diff --git a/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java b/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java new file mode 100644 index 0000000..af14e95 --- /dev/null +++ b/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java @@ -0,0 +1,30 @@ +package org.wlld.nerveEntity; + +import java.util.ArrayList; +import java.util.List; + +/** + * @author lidapeng + * @description 动态神经元模型参数 + * @date 8:14 上午 2020/1/18 + */ +public class DymNerveStudy { + private List list = new ArrayList<>(); + private double threshold;//此神经元的阈值需要取出 + + public List getList() { + return list; + } + + public void setList(List list) { + this.list = list; + } + + public double getThreshold() { + return threshold; + } + + public void setThreshold(double threshold) { + this.threshold = threshold; + } +} diff --git a/src/main/java/org/wlld/nerveEntity/ModelParameter.java b/src/main/java/org/wlld/nerveEntity/ModelParameter.java index 2e00cf7..daf82d4 100644 --- a/src/main/java/org/wlld/nerveEntity/ModelParameter.java +++ b/src/main/java/org/wlld/nerveEntity/ModelParameter.java @@ -14,6 +14,24 @@ public class ModelParameter { //神经远模型参数 private List> depthNerves = new ArrayList<>();//隐层神经元 private List outNevers = new ArrayList<>();//输出神经元 + private List dymNerveStudies = new ArrayList<>();//动态神经元隐层 + private DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层 + + public List getDymNerveStudies() { + return dymNerveStudies; + } + + public void setDymNerveStudies(List dymNerveStudies) { + this.dymNerveStudies = dymNerveStudies; + } + + public DymNerveStudy getDymOutNerveStudy() { + return dymOutNerveStudy; + } + + public void setDymOutNerveStudy(DymNerveStudy dymOutNerveStudy) { + this.dymOutNerveStudy = dymOutNerveStudy; + } public List> getDepthNerves() { return depthNerves; diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index 9b88a26..9b2f2d7 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -39,6 +39,14 @@ public abstract class Nerve { return dendrites; } + public Matrix getNerveMatrix() { + return nerveMatrix; + } + + public void setNerveMatrix(Matrix nerveMatrix) { + this.nerveMatrix = nerveMatrix; + } + public void setDendrites(Map dendrites) { this.dendrites = dendrites; } diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index 4353ab0..bb1397f 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -69,7 +69,7 @@ public class OutNerve extends Nerve { matrixF = new Matrix(myMatrix.getX(), myMatrix.getY()); } if (isKernelStudy) {//回传 - // System.out.println(myMatrix.getString()); + // System.out.println(myMatrix.getString()); for (Map.Entry entry : E.entrySet()) { double g; if (entry.getValue() > 0.5) {//正模板