增加隐层不同层学习率设置

pull/1/head
lidapeng 5 years ago
parent f8c9f6d490
commit 577649f19c

@ -16,6 +16,15 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.51</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<pluginManagement>
<plugins>

@ -1,15 +0,0 @@
package org.wlld;
import org.wlld.i.OutBack;
/**
* @author lidapeng
* @description
* @date 1:19 2019/12/24
*/
public class Test implements OutBack {
@Override
public void getBack(double out, int id, long eventId) {
System.out.println("out==" + out + ",id==" + id + ",eventId==" + eventId);
}
}

@ -116,6 +116,16 @@ public class TempleConfig {
nerveManager.setStudyPoint(studyPoint);
}
public void setStudyList(List<Double> list) {//设置每一层不同的学习率
if (studyPattern == StudyPattern.Accuracy_Pattern) {
//给卷积层设置层学习率
convolutionNerveManager.setStudyList(list);
} else if (studyPattern == StudyPattern.Speed_Pattern) {
//给全连接层设置学习率
nerveManager.setStudyList(list);
}
}
public int getRow() {
return row;
}
@ -131,6 +141,9 @@ public class TempleConfig {
//注入模型参数
public void insertModel(ModelParameter modelParameter) throws Exception {
nerveManager.insertModelParameter(modelParameter);
if (studyPattern == StudyPattern.Accuracy_Pattern) {
convolutionNerveManager.insertModelParameter(modelParameter);
}
}
public void setCutThreshold(double cutThreshold) {

@ -31,6 +31,15 @@ public class NerveManager {
private ActiveFunction activeFunction;
private Map<Integer, Matrix> matrixMap = new HashMap<>();//主键与期望矩阵的映射
private boolean isDynamic;//是否是动态神经网络
private List<Double> studyList = new ArrayList<>();
public List<Double> getStudyList() {//查看每一次的学习率
return studyList;
}
public void setStudyList(List<Double> studyList) {//设置每一层的学习率
this.studyList = studyList;
}
public void setMatrixMap(Map<Integer, Matrix> matrixMap) {
this.matrixMap = matrixMap;
@ -42,7 +51,7 @@ public class NerveManager {
public void setStudyPoint(double studyPoint) throws Exception {
//设置学习率
if (studyPoint < 1 && studyPoint > 0) {
if (studyPoint <= 1 && studyPoint > 0) {
this.studyPoint = studyPoint;
} else {
throw new Exception("studyPoint Values range from 0 to 1");
@ -119,9 +128,10 @@ public class NerveManager {
}
public void insertModelParameter(ModelParameter modelParameter) throws Exception {
insertBpModelParameter(modelParameter);//全连接层注入参数
if (isDynamic) {
insertConvolutionModelParameter(modelParameter);
insertConvolutionModelParameter(modelParameter);//动态神经元注入
} else {
insertBpModelParameter(modelParameter);//全连接层注入参数
}
}
@ -278,6 +288,13 @@ public class NerveManager {
private void initDepthNerve(boolean isMatrix) throws Exception {//初始化隐层神经元1
for (int i = 0; i < hiddenDepth; i++) {//遍历深度
List<Nerve> hiddenNerveList = new ArrayList<>();
double studyPoint = this.studyPoint;
if (studyList.contains(i)) {//加载每一层的学习率
studyPoint = studyList.get(i);
}
if (studyPoint <= 0 || studyPoint > 1) {
throw new Exception("studyPoint Values range from 0 to 1");
}
for (int j = 1; j < hiddenNerverNub + 1; j++) {//遍历同级
int upNub = 0;
int downNub = 0;

@ -1,57 +0,0 @@
package org.wlld.test;
import org.wlld.MatrixTools.Matrix;
import org.wlld.imageRecognition.Operation;
import org.wlld.imageRecognition.Picture;
import org.wlld.imageRecognition.TempleConfig;
import org.wlld.tools.ArithUtil;
/**
* @author lidapeng
* @description
* @date 2:11 2020/1/7
*/
public class Test {
public static Matrix E;
public static Matrix F;
static {
try {
E = getE(true);
F = getE(false);
} catch (Exception e) {
e.printStackTrace();
}
}
public static void main(String[] args) throws Exception {
double d = ArithUtil.div(3204, 4032);
int a = (int) (d * 5);
System.out.println(a);
}
public static void test() throws Exception {
}
public static Matrix getE(boolean isRight) throws Exception {
Matrix matrix = new Matrix(5, 4);
String name;
if (isRight) {
name = "[10,10,10,0]#" +
"[10,10,10,0]#" +
"[10,10,10,0]#" +
"[10,10,10,0]#" +
"[10,10,10,0]#";
} else {
name = "[1,1,1,0]#" +
"[1,1,1,0]#" +
"[1,1,1,0]#" +
"[1,1,1,0]#" +
"[1,1,1,0]#";
}
matrix.setAll(name);
return matrix;
}
}

@ -1,6 +1,7 @@
package org.wlld;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.wlld.MatrixTools.Matrix;
import org.wlld.config.StudyPattern;
import org.wlld.imageRecognition.Operation;
@ -13,14 +14,32 @@ import java.util.HashMap;
import java.util.Map;
/**
* !
* @author lidapeng
* @description
* @date 11:35 2020/1/18
*/
public class HelloWorld {
public static void main(String[] args) throws Exception {
testPic2();
//testPic2();
testModel();
}
public static void testModel() throws Exception {
// 模型参数获取及注入
TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern);
ModelParameter modelParameter1 = templeConfig.getModel();
String model = JSON.toJSONString(modelParameter1);
System.out.println(model);
TempleConfig templeConfig2 = getTemple(false, StudyPattern.Accuracy_Pattern);
ModelParameter modelParameter3 = JSONObject.parseObject(model, ModelParameter.class);
templeConfig2.insertModel(modelParameter3);
ModelParameter modelParameter2 = templeConfig2.getModel();
String model2 = JSON.toJSONString(modelParameter2);
System.out.println(model2);
}
public static void testPic2() throws Exception {
public static void testPic2() throws Exception {//测试Accuracy_Pattern 模式学习
Picture picture = new Picture();
TempleConfig templeConfig = getTemple(true, StudyPattern.Accuracy_Pattern);
Operation operation = new Operation(templeConfig);
@ -53,7 +72,7 @@ public class HelloWorld {
operation.look(wrong, 3);
}
public static void testPic() throws Exception {
public static void testPic() throws Exception {//测试SPEED模式学习
//初始化图像转矩阵类
Picture picture = new Picture();
//初始化配置模板类
@ -79,13 +98,13 @@ public class HelloWorld {
//获取模型MODLE
ModelParameter modelParameter = templeConfig.getModel();
//将模型MODEL转化成JSON 字符串
//String model = JSON.toJSONString(modelParameter);
String model = JSON.toJSONString(modelParameter);
//将JSON字符串转化为模型MODEL
//ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
ModelParameter modelParameter1 = JSONObject.parseObject(model, ModelParameter.class);
//初始化模型配置
TempleConfig templeConfig1 = getTemple(false, StudyPattern.Speed_Pattern);
//注入之前学习结果的模型MODEL到配置模版里面
templeConfig1.insertModel(modelParameter);
templeConfig1.insertModel(modelParameter1);
//将配置模板配置到运算类
Operation operation1 = new Operation(templeConfig1);
//获取本地图片字节码转化成降纬后的灰度矩阵
Loading…
Cancel
Save