diff --git a/pom.xml b/pom.xml
index dc88620..712ab8a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -11,22 +11,6 @@
myBrain
http://www.example.com
-
-
-
- The Apache Software License, Version 2.0
- http://www.apache.org/licenses/LICENSE-2.0.txt
- repo
-
-
-
-
- thenk008
- 794757862@qq.com
- hope-redheart
- https://www.cnblogs.com/yjp372928571
-
-
UTF-8
1.8
@@ -48,32 +32,7 @@
-
- org.apache.maven.plugins
- maven-gpg-plugin
- 1.6
-
-
- verify
-
- sign
-
-
-
-
-
-
- releases
- Nexus Release Repository
- https://oss.sonatype.org/service/local/staging/deploy/maven2
-
-
- snapshots
- Nexus Snapshot Repository
- https://oss.sonatype.org/content/repositories/snapshots
-
-
diff --git a/src/main/java/org/wlld/HelloWorld.java b/src/main/java/org/wlld/HelloWorld.java
index 704aa56..6a022d0 100644
--- a/src/main/java/org/wlld/HelloWorld.java
+++ b/src/main/java/org/wlld/HelloWorld.java
@@ -29,7 +29,7 @@ public class HelloWorld {
Map wrongTagging = new HashMap<>();//分类标注
rightTagging.put(1, 1.0);
wrongTagging.put(1, 0.0);
- for (int i = 1; i < 500; i++) {
+ for (int i = 1; i < 5; i++) {
System.out.println("开始学习1==" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
@@ -38,7 +38,7 @@ public class HelloWorld {
operation.learning(right, rightTagging, false);
operation.learning(wrong, wrongTagging, false);
}
- for (int i = 1; i < 500; i++) {//神经网络学习
+ for (int i = 1; i < 5; i++) {//神经网络学习
System.out.println("开始学习2==" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + i + ".png");
diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java
index ab5b1d0..a44a69c 100644
--- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java
+++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java
@@ -1,7 +1,6 @@
package org.wlld.imageRecognition;
import org.wlld.MatrixTools.Matrix;
-import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.config.StudyPattern;
import org.wlld.function.ReLu;
import org.wlld.function.Sigmod;
@@ -99,8 +98,14 @@ public class TempleConfig {
convolutionNerveManager.init(initPower, true, nerveManager);
}
- public ModelParameter getModel() {//获取模型参数
- return nerveManager.getModelParameter();
+ public ModelParameter getModel() throws Exception {//获取模型参数
+ ModelParameter modelParameter = nerveManager.getModelParameter();
+ if (studyPattern == StudyPattern.Accuracy_Pattern) {
+ ModelParameter modelParameter1 = convolutionNerveManager.getModelParameter();
+ modelParameter.setDymNerveStudies(modelParameter1.getDymNerveStudies());
+ modelParameter.setDymOutNerveStudy(modelParameter1.getDymOutNerveStudy());
+ }
+ return modelParameter;
}
public List getSensoryNerves() {//获取感知神经元
diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java
index d4b4fb8..52726c5 100644
--- a/src/main/java/org/wlld/nerveCenter/NerveManager.java
+++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java
@@ -4,6 +4,7 @@ import org.wlld.MatrixTools.Matrix;
import org.wlld.i.ActiveFunction;
import org.wlld.i.OutBack;
import org.wlld.nerveEntity.*;
+import org.wlld.tools.ArithUtil;
import java.util.ArrayList;
import java.util.HashMap;
@@ -48,7 +49,46 @@ public class NerveManager {
}
}
- public ModelParameter getModelParameter() {//获取当前模型参数
+ private ModelParameter getDymModelParameter() throws Exception {//获取动态神经元参数
+ ModelParameter modelParameter = new ModelParameter();
+ List dymNerveStudies = new ArrayList<>();//动态神经元隐层
+ DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
+ modelParameter.setDymNerveStudies(dymNerveStudies);
+ modelParameter.setDymOutNerveStudy(dymOutNerveStudy);
+ for (int i = 0; i < depthNerves.size(); i++) {
+ Nerve depthNerve = depthNerves.get(i).get(0);//隐层神经元
+ DymNerveStudy deepNerveStudy = new DymNerveStudy();//动态神经元输出层
+ List list = deepNerveStudy.getList();
+ deepNerveStudy.setThreshold(depthNerve.getThreshold());//获取偏移值
+ Matrix matrix = depthNerve.getNerveMatrix();
+ insertWList(matrix, list);
+ dymNerveStudies.add(deepNerveStudy);
+ }
+ Nerve outNerve = outNevers.get(0);
+ Matrix matrix = outNerve.getNerveMatrix();
+ dymOutNerveStudy.setThreshold(outNerve.getThreshold());
+ List list = dymOutNerveStudy.getList();
+ insertWList(matrix, list);
+ return modelParameter;
+ }
+
+ private void insertWList(Matrix matrix, List list) throws Exception {//
+ for (int i = 0; i < matrix.getX(); i++) {
+ for (int j = 0; j < matrix.getY(); j++) {
+ list.add(matrix.getNumber(i, j));
+ }
+ }
+ }
+
+ public ModelParameter getModelParameter() throws Exception {
+ if (isDynamic) {
+ return getDymModelParameter();
+ } else {
+ return getStaticModelParameter();
+ }
+ }
+
+ private ModelParameter getStaticModelParameter() {//获取当前模型参数
ModelParameter modelParameter = new ModelParameter();
List> studyDepthNerves = new ArrayList<>();//隐层神经元模型
List outStudyNevers = new ArrayList<>();//输出神经元
@@ -78,8 +118,42 @@ public class NerveManager {
return modelParameter;
}
- //注入模型参数
- public void insertModelParameter(ModelParameter modelParameter) {
+ public void insertModelParameter(ModelParameter modelParameter) throws Exception {
+ insertBpModelParameter(modelParameter);//全连接层注入参数
+ if (isDynamic) {
+ insertConvolutionModelParameter(modelParameter);
+ }
+ }
+
+ //注入卷积层模型参数
+ private void insertConvolutionModelParameter(ModelParameter modelParameter) throws Exception {
+ List dymNerveStudyList = modelParameter.getDymNerveStudies();
+ DymNerveStudy dymOutNerveStudy = modelParameter.getDymOutNerveStudy();
+ for (int i = 0; i < depthNerves.size(); i++) {
+ Nerve depthNerve = depthNerves.get(i).get(0);
+ DymNerveStudy dymNerveStudy = dymNerveStudyList.get(i);
+ List list = dymNerveStudy.getList();
+ Matrix nerveMatrix = depthNerve.getNerveMatrix();
+ depthNerve.setThreshold(dymNerveStudy.getThreshold());//注入偏置项
+ insertMatrix(nerveMatrix, list);
+ }
+ Nerve outNerve = outNevers.get(0);
+ outNerve.setThreshold(dymOutNerveStudy.getThreshold());//输出神经元注入偏置项
+ Matrix outNervMatrix = outNerve.getNerveMatrix();
+ List list = dymOutNerveStudy.getList();
+ insertMatrix(outNervMatrix, list);
+ }
+
+ private void insertMatrix(Matrix matrix, List list) throws Exception {
+ for (int i = 0; i < list.size(); i++) {
+ int x = i / 3;
+ int y = i % 3;
+ matrix.setNub(x, y, list.get(i));
+ }
+ }
+
+ //注入全连接模型参数
+ private void insertBpModelParameter(ModelParameter modelParameter) {
List> depthStudyNerves = modelParameter.getDepthNerves();//隐层神经元
List outStudyNevers = modelParameter.getOutNevers();//输出神经元
//隐层神经元参数注入
diff --git a/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java b/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java
new file mode 100644
index 0000000..af14e95
--- /dev/null
+++ b/src/main/java/org/wlld/nerveEntity/DymNerveStudy.java
@@ -0,0 +1,30 @@
+package org.wlld.nerveEntity;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * @author lidapeng
+ * @description 动态神经元模型参数
+ * @date 8:14 上午 2020/1/18
+ */
+public class DymNerveStudy {
+ private List list = new ArrayList<>();
+ private double threshold;//此神经元的阈值需要取出
+
+ public List getList() {
+ return list;
+ }
+
+ public void setList(List list) {
+ this.list = list;
+ }
+
+ public double getThreshold() {
+ return threshold;
+ }
+
+ public void setThreshold(double threshold) {
+ this.threshold = threshold;
+ }
+}
diff --git a/src/main/java/org/wlld/nerveEntity/ModelParameter.java b/src/main/java/org/wlld/nerveEntity/ModelParameter.java
index 2e00cf7..daf82d4 100644
--- a/src/main/java/org/wlld/nerveEntity/ModelParameter.java
+++ b/src/main/java/org/wlld/nerveEntity/ModelParameter.java
@@ -14,6 +14,24 @@ public class ModelParameter {
//神经远模型参数
private List> depthNerves = new ArrayList<>();//隐层神经元
private List outNevers = new ArrayList<>();//输出神经元
+ private List dymNerveStudies = new ArrayList<>();//动态神经元隐层
+ private DymNerveStudy dymOutNerveStudy = new DymNerveStudy();//动态神经元输出层
+
+ public List getDymNerveStudies() {
+ return dymNerveStudies;
+ }
+
+ public void setDymNerveStudies(List dymNerveStudies) {
+ this.dymNerveStudies = dymNerveStudies;
+ }
+
+ public DymNerveStudy getDymOutNerveStudy() {
+ return dymOutNerveStudy;
+ }
+
+ public void setDymOutNerveStudy(DymNerveStudy dymOutNerveStudy) {
+ this.dymOutNerveStudy = dymOutNerveStudy;
+ }
public List> getDepthNerves() {
return depthNerves;
diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java
index 9b88a26..9b2f2d7 100644
--- a/src/main/java/org/wlld/nerveEntity/Nerve.java
+++ b/src/main/java/org/wlld/nerveEntity/Nerve.java
@@ -39,6 +39,14 @@ public abstract class Nerve {
return dendrites;
}
+ public Matrix getNerveMatrix() {
+ return nerveMatrix;
+ }
+
+ public void setNerveMatrix(Matrix nerveMatrix) {
+ this.nerveMatrix = nerveMatrix;
+ }
+
public void setDendrites(Map dendrites) {
this.dendrites = dendrites;
}
diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java
index 4353ab0..bb1397f 100644
--- a/src/main/java/org/wlld/nerveEntity/OutNerve.java
+++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java
@@ -69,7 +69,7 @@ public class OutNerve extends Nerve {
matrixF = new Matrix(myMatrix.getX(), myMatrix.getY());
}
if (isKernelStudy) {//回传
- // System.out.println(myMatrix.getString());
+ // System.out.println(myMatrix.getString());
for (Map.Entry entry : E.entrySet()) {
double g;
if (entry.getValue() > 0.5) {//正模板