diff --git a/pom.xml b/pom.xml index 712ab8a..affac86 100644 --- a/pom.xml +++ b/pom.xml @@ -16,6 +16,15 @@ 1.8 1.8 + + + com.alibaba + fastjson + 1.2.51 + test + + + diff --git a/src/main/java/org/wlld/Test.java b/src/main/java/org/wlld/Test.java deleted file mode 100644 index 7548395..0000000 --- a/src/main/java/org/wlld/Test.java +++ /dev/null @@ -1,15 +0,0 @@ -package org.wlld; - -import org.wlld.i.OutBack; - -/** - * @author lidapeng - * @description - * @date 1:19 下午 2019/12/24 - */ -public class Test implements OutBack { - @Override - public void getBack(double out, int id, long eventId) { - System.out.println("out==" + out + ",id==" + id + ",eventId==" + eventId); - } -} diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 1c0770d..7778387 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -116,6 +116,16 @@ public class TempleConfig { nerveManager.setStudyPoint(studyPoint); } + public void setStudyList(List list) {//设置每一层不同的学习率 + if (studyPattern == StudyPattern.Accuracy_Pattern) { + //给卷积层设置层学习率 + convolutionNerveManager.setStudyList(list); + } else if (studyPattern == StudyPattern.Speed_Pattern) { + //给全连接层设置学习率 + nerveManager.setStudyList(list); + } + } + public int getRow() { return row; } @@ -131,6 +141,9 @@ public class TempleConfig { //注入模型参数 public void insertModel(ModelParameter modelParameter) throws Exception { nerveManager.insertModelParameter(modelParameter); + if (studyPattern == StudyPattern.Accuracy_Pattern) { + convolutionNerveManager.insertModelParameter(modelParameter); + } } public void setCutThreshold(double cutThreshold) { diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java index 52726c5..52dbcbe 100644 --- a/src/main/java/org/wlld/nerveCenter/NerveManager.java +++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java @@ -31,6 +31,15 @@ public class NerveManager { private ActiveFunction activeFunction; private Map matrixMap = new HashMap<>();//主键与期望矩阵的映射 private boolean isDynamic;//是否是动态神经网络 + private List studyList = new ArrayList<>(); + + public List getStudyList() {//查看每一次的学习率 + return studyList; + } + + public void setStudyList(List studyList) {//设置每一层的学习率 + this.studyList = studyList; + } public void setMatrixMap(Map matrixMap) { this.matrixMap = matrixMap; @@ -42,7 +51,7 @@ public class NerveManager { public void setStudyPoint(double studyPoint) throws Exception { //设置学习率 - if (studyPoint < 1 && studyPoint > 0) { + if (studyPoint <= 1 && studyPoint > 0) { this.studyPoint = studyPoint; } else { throw new Exception("studyPoint Values range from 0 to 1"); @@ -119,9 +128,10 @@ public class NerveManager { } public void insertModelParameter(ModelParameter modelParameter) throws Exception { - insertBpModelParameter(modelParameter);//全连接层注入参数 if (isDynamic) { - insertConvolutionModelParameter(modelParameter); + insertConvolutionModelParameter(modelParameter);//动态神经元注入 + } else { + insertBpModelParameter(modelParameter);//全连接层注入参数 } } @@ -278,6 +288,13 @@ public class NerveManager { private void initDepthNerve(boolean isMatrix) throws Exception {//初始化隐层神经元1 for (int i = 0; i < hiddenDepth; i++) {//遍历深度 List hiddenNerveList = new ArrayList<>(); + double studyPoint = this.studyPoint; + if (studyList.contains(i)) {//加载每一层的学习率 + studyPoint = studyList.get(i); + } + if (studyPoint <= 0 || studyPoint > 1) { + throw new Exception("studyPoint Values range from 0 to 1"); + } for (int j = 1; j < hiddenNerverNub + 1; j++) {//遍历同级 int upNub = 0; int downNub = 0; diff --git a/src/main/java/org/wlld/test/Test.java b/src/main/java/org/wlld/test/Test.java deleted file mode 100644 index 378e272..0000000 --- a/src/main/java/org/wlld/test/Test.java +++ /dev/null @@ -1,57 +0,0 @@ -package org.wlld.test; - -import org.wlld.MatrixTools.Matrix; -import org.wlld.imageRecognition.Operation; -import org.wlld.imageRecognition.Picture; -import org.wlld.imageRecognition.TempleConfig; -import org.wlld.tools.ArithUtil; - -/** - * @author lidapeng - * @description - * @date 2:11 下午 2020/1/7 - */ -public class Test { - public static Matrix E; - public static Matrix F; - - static { - try { - E = getE(true); - F = getE(false); - } catch (Exception e) { - e.printStackTrace(); - } - } - - public static void main(String[] args) throws Exception { - double d = ArithUtil.div(3204, 4032); - int a = (int) (d * 5); - System.out.println(a); - } - - public static void test() throws Exception { - - } - - public static Matrix getE(boolean isRight) throws Exception { - Matrix matrix = new Matrix(5, 4); - String name; - if (isRight) { - name = "[10,10,10,0]#" + - "[10,10,10,0]#" + - "[10,10,10,0]#" + - "[10,10,10,0]#" + - "[10,10,10,0]#"; - } else { - name = "[1,1,1,0]#" + - "[1,1,1,0]#" + - "[1,1,1,0]#" + - "[1,1,1,0]#" + - "[1,1,1,0]#"; - } - - matrix.setAll(name); - return matrix; - } -} diff --git a/src/main/java/org/wlld/HelloWorld.java b/src/test/java/org/wlld/HelloWorld.java similarity index 79% rename from src/main/java/org/wlld/HelloWorld.java rename to src/test/java/org/wlld/HelloWorld.java index 6a022d0..80d759a 100644 --- a/src/main/java/org/wlld/HelloWorld.java +++ b/src/test/java/org/wlld/HelloWorld.java @@ -1,6 +1,7 @@ package org.wlld; - +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; import org.wlld.MatrixTools.Matrix; import org.wlld.config.StudyPattern; import org.wlld.imageRecognition.Operation; @@ -13,14 +14,32 @@ import java.util.HashMap; import java.util.Map; /** - * 测试入口类! + * @author lidapeng + * @description 测试入口类 + * @date 11:35 上午 2020/1/18 */ public class HelloWorld { public static void main(String[] args) throws Exception { - testPic2(); + //testPic2(); + testModel(); + } + + public static void testModel() throws Exception { + // 模型参数获取及注入 + TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern); + ModelParameter modelParameter1 = templeConfig.getModel(); + String model = JSON.toJSONString(modelParameter1); + System.out.println(model); + TempleConfig templeConfig2 = getTemple(false, StudyPattern.Accuracy_Pattern); + ModelParameter modelParameter3 = JSONObject.parseObject(model, ModelParameter.class); + templeConfig2.insertModel(modelParameter3); + ModelParameter modelParameter2 = templeConfig2.getModel(); + String model2 = JSON.toJSONString(modelParameter2); + System.out.println(model2); + } - public static void testPic2() throws Exception { + public static void testPic2() throws Exception {//测试Accuracy_Pattern 模式学习 Picture picture = new Picture(); TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern); Operation operation = new Operation(templeConfig); @@ -53,7 +72,7 @@ public class HelloWorld { operation.look(wrong, 3); } - public static void testPic() throws Exception { + public static void testPic() throws Exception {//测试SPEED模式学习 //初始化图像转矩阵类 Picture picture = new Picture(); //初始化配置模板类 @@ -79,13 +98,13 @@ public class HelloWorld { //获取模型MODLE ModelParameter modelParameter = templeConfig.getModel(); //将模型MODEL转化成JSON 字符串 - //String model = JSON.toJSONString(modelParameter); + String model = JSON.toJSONString(modelParameter); //将JSON字符串转化为模型MODEL - //ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class); + ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class); //初始化模型配置 TempleConfig templeConfig1 = getTemple(false, StudyPattern.Speed_Pattern); //注入之前学习结果的模型MODEL到配置模版里面 - templeConfig1.insertModel(modelParameter); + templeConfig1.insertModel(modelParameter1); //将配置模板配置到运算类 Operation operation1 = new Operation(templeConfig1); //获取本地图片字节码转化成降纬后的灰度矩阵