diff --git a/pom.xml b/pom.xml
index 712ab8a..affac86 100644
--- a/pom.xml
+++ b/pom.xml
@@ -16,6 +16,15 @@
1.8
1.8
+
+
+ com.alibaba
+ fastjson
+ 1.2.51
+ test
+
+
+
diff --git a/src/main/java/org/wlld/Test.java b/src/main/java/org/wlld/Test.java
deleted file mode 100644
index 7548395..0000000
--- a/src/main/java/org/wlld/Test.java
+++ /dev/null
@@ -1,15 +0,0 @@
-package org.wlld;
-
-import org.wlld.i.OutBack;
-
-/**
- * @author lidapeng
- * @description
- * @date 1:19 下午 2019/12/24
- */
-public class Test implements OutBack {
- @Override
- public void getBack(double out, int id, long eventId) {
- System.out.println("out==" + out + ",id==" + id + ",eventId==" + eventId);
- }
-}
diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java
index 1c0770d..7778387 100644
--- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java
+++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java
@@ -116,6 +116,16 @@ public class TempleConfig {
nerveManager.setStudyPoint(studyPoint);
}
+ public void setStudyList(List list) {//设置每一层不同的学习率
+ if (studyPattern == StudyPattern.Accuracy_Pattern) {
+ //给卷积层设置层学习率
+ convolutionNerveManager.setStudyList(list);
+ } else if (studyPattern == StudyPattern.Speed_Pattern) {
+ //给全连接层设置学习率
+ nerveManager.setStudyList(list);
+ }
+ }
+
public int getRow() {
return row;
}
@@ -131,6 +141,9 @@ public class TempleConfig {
//注入模型参数
public void insertModel(ModelParameter modelParameter) throws Exception {
nerveManager.insertModelParameter(modelParameter);
+ if (studyPattern == StudyPattern.Accuracy_Pattern) {
+ convolutionNerveManager.insertModelParameter(modelParameter);
+ }
}
public void setCutThreshold(double cutThreshold) {
diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java
index 52726c5..52dbcbe 100644
--- a/src/main/java/org/wlld/nerveCenter/NerveManager.java
+++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java
@@ -31,6 +31,15 @@ public class NerveManager {
private ActiveFunction activeFunction;
private Map matrixMap = new HashMap<>();//主键与期望矩阵的映射
private boolean isDynamic;//是否是动态神经网络
+ private List studyList = new ArrayList<>();
+
+ public List getStudyList() {//查看每一次的学习率
+ return studyList;
+ }
+
+ public void setStudyList(List studyList) {//设置每一层的学习率
+ this.studyList = studyList;
+ }
public void setMatrixMap(Map matrixMap) {
this.matrixMap = matrixMap;
@@ -42,7 +51,7 @@ public class NerveManager {
public void setStudyPoint(double studyPoint) throws Exception {
//设置学习率
- if (studyPoint < 1 && studyPoint > 0) {
+ if (studyPoint <= 1 && studyPoint > 0) {
this.studyPoint = studyPoint;
} else {
throw new Exception("studyPoint Values range from 0 to 1");
@@ -119,9 +128,10 @@ public class NerveManager {
}
public void insertModelParameter(ModelParameter modelParameter) throws Exception {
- insertBpModelParameter(modelParameter);//全连接层注入参数
if (isDynamic) {
- insertConvolutionModelParameter(modelParameter);
+ insertConvolutionModelParameter(modelParameter);//动态神经元注入
+ } else {
+ insertBpModelParameter(modelParameter);//全连接层注入参数
}
}
@@ -278,6 +288,13 @@ public class NerveManager {
private void initDepthNerve(boolean isMatrix) throws Exception {//初始化隐层神经元1
for (int i = 0; i < hiddenDepth; i++) {//遍历深度
List hiddenNerveList = new ArrayList<>();
+ double studyPoint = this.studyPoint;
+ if (studyList.contains(i)) {//加载每一层的学习率
+ studyPoint = studyList.get(i);
+ }
+ if (studyPoint <= 0 || studyPoint > 1) {
+ throw new Exception("studyPoint Values range from 0 to 1");
+ }
for (int j = 1; j < hiddenNerverNub + 1; j++) {//遍历同级
int upNub = 0;
int downNub = 0;
diff --git a/src/main/java/org/wlld/test/Test.java b/src/main/java/org/wlld/test/Test.java
deleted file mode 100644
index 378e272..0000000
--- a/src/main/java/org/wlld/test/Test.java
+++ /dev/null
@@ -1,57 +0,0 @@
-package org.wlld.test;
-
-import org.wlld.MatrixTools.Matrix;
-import org.wlld.imageRecognition.Operation;
-import org.wlld.imageRecognition.Picture;
-import org.wlld.imageRecognition.TempleConfig;
-import org.wlld.tools.ArithUtil;
-
-/**
- * @author lidapeng
- * @description
- * @date 2:11 下午 2020/1/7
- */
-public class Test {
- public static Matrix E;
- public static Matrix F;
-
- static {
- try {
- E = getE(true);
- F = getE(false);
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
-
- public static void main(String[] args) throws Exception {
- double d = ArithUtil.div(3204, 4032);
- int a = (int) (d * 5);
- System.out.println(a);
- }
-
- public static void test() throws Exception {
-
- }
-
- public static Matrix getE(boolean isRight) throws Exception {
- Matrix matrix = new Matrix(5, 4);
- String name;
- if (isRight) {
- name = "[10,10,10,0]#" +
- "[10,10,10,0]#" +
- "[10,10,10,0]#" +
- "[10,10,10,0]#" +
- "[10,10,10,0]#";
- } else {
- name = "[1,1,1,0]#" +
- "[1,1,1,0]#" +
- "[1,1,1,0]#" +
- "[1,1,1,0]#" +
- "[1,1,1,0]#";
- }
-
- matrix.setAll(name);
- return matrix;
- }
-}
diff --git a/src/main/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java
similarity index 79%
rename from src/main/java/org/wlld/HelloWorld.java
rename to src/test/java/org/wlld/HelloWorld.java
index 6a022d0..80d759a 100644
--- a/src/main/java/org/wlld/HelloWorld.java
+++ b/src/test/java/org/wlld/HelloWorld.java
@@ -1,6 +1,7 @@
package org.wlld;
-
+import com.alibaba.fastjson.JSON;
+import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.config.StudyPattern;
import org.wlld.imageRecognition.Operation;
@@ -13,14 +14,32 @@ import java.util.HashMap;
import java.util.Map;
/**
- * 测试入口类!
+ * @author lidapeng
+ * @description 测试入口类
+ * @date 11:35 上午 2020/1/18
*/
public class HelloWorld {
public static void main(String[] args) throws Exception {
- testPic2();
+ //testPic2();
+ testModel();
+ }
+
+ public static void testModel() throws Exception {
+ // 模型参数获取及注入
+ TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern);
+ ModelParameter modelParameter1 = templeConfig.getModel();
+ String model = JSON.toJSONString(modelParameter1);
+ System.out.println(model);
+ TempleConfig templeConfig2 = getTemple(false, StudyPattern.Accuracy_Pattern);
+ ModelParameter modelParameter3 = JSONObject.parseObject(model, ModelParameter.class);
+ templeConfig2.insertModel(modelParameter3);
+ ModelParameter modelParameter2 = templeConfig2.getModel();
+ String model2 = JSON.toJSONString(modelParameter2);
+ System.out.println(model2);
+
}
- public static void testPic2() throws Exception {
+ public static void testPic2() throws Exception {//测试Accuracy_Pattern 模式学习
Picture picture = new Picture();
TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern);
Operation operation = new Operation(templeConfig);
@@ -53,7 +72,7 @@ public class HelloWorld {
operation.look(wrong, 3);
}
- public static void testPic() throws Exception {
+ public static void testPic() throws Exception {//测试SPEED模式学习
//初始化图像转矩阵类
Picture picture = new Picture();
//初始化配置模板类
@@ -79,13 +98,13 @@ public class HelloWorld {
//获取模型MODLE
ModelParameter modelParameter = templeConfig.getModel();
//将模型MODEL转化成JSON 字符串
- //String model = JSON.toJSONString(modelParameter);
+ String model = JSON.toJSONString(modelParameter);
//将JSON字符串转化为模型MODEL
- //ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
+ ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
//初始化模型配置
TempleConfig templeConfig1 = getTemple(false, StudyPattern.Speed_Pattern);
//注入之前学习结果的模型MODEL到配置模版里面
- templeConfig1.insertModel(modelParameter);
+ templeConfig1.insertModel(modelParameter1);
//将配置模板配置到运算类
Operation operation1 = new Operation(templeConfig1);
//获取本地图片字节码转化成降纬后的灰度矩阵