增加TANH 激活函数,默认激活函数使用TANH

pull/7/head
Administrator 5 years ago
parent 3661bebdee
commit 49d3c71d39

@ -0,0 +1,18 @@
package org.wlld.function;
import org.wlld.i.ActiveFunction;
import org.wlld.tools.ArithUtil;
public class Tanh implements ActiveFunction {
@Override
public double function(double x) {
double son = ArithUtil.sub(Math.exp(x), Math.exp(-x));
double mother = ArithUtil.add(Math.exp(x), Math.exp(-x));
return ArithUtil.div(son, mother);
}
@Override
public double functionG(double out) {
return ArithUtil.sub(1, Math.pow(function(out), 2));
}
}

@ -6,6 +6,8 @@ import org.wlld.config.Classifier;
import org.wlld.config.StudyPattern; import org.wlld.config.StudyPattern;
import org.wlld.function.ReLu; import org.wlld.function.ReLu;
import org.wlld.function.Sigmod; import org.wlld.function.Sigmod;
import org.wlld.function.Tanh;
import org.wlld.i.ActiveFunction;
import org.wlld.imageRecognition.border.*; import org.wlld.imageRecognition.border.*;
import org.wlld.imageRecognition.modelEntity.BoxList; import org.wlld.imageRecognition.modelEntity.BoxList;
import org.wlld.imageRecognition.modelEntity.KBorder; import org.wlld.imageRecognition.modelEntity.KBorder;
@ -49,7 +51,7 @@ public class TempleConfig {
private double avg = 0;//覆盖均值 private double avg = 0;//覆盖均值
private int sensoryNerveNub;//输入神经元个数 private int sensoryNerveNub;//输入神经元个数
private boolean isShowLog = false; private boolean isShowLog = false;
private ActiveFunction activeFunction = new Tanh();
public boolean isAccurate() { public boolean isAccurate() {
return isAccurate; return isAccurate;
} }
@ -58,6 +60,10 @@ public class TempleConfig {
isAccurate = accurate; isAccurate = accurate;
} }
public void setActiveFunction(ActiveFunction activeFunction) {
this.activeFunction = activeFunction;
}
public double getAvg() { public double getAvg() {
return avg; return avg;
} }
@ -247,7 +253,7 @@ public class TempleConfig {
private void initNerveManager(boolean initPower, int sensoryNerveNub private void initNerveManager(boolean initPower, int sensoryNerveNub
, int deep) throws Exception { , int deep) throws Exception {
nerveManager = new NerveManager(sensoryNerveNub, 9, nerveManager = new NerveManager(sensoryNerveNub, 9,
classificationNub, deep, new Sigmod(), false, isAccurate); classificationNub, deep, activeFunction, false, isAccurate);
nerveManager.init(initPower, false, isShowLog); nerveManager.init(initPower, false, isShowLog);
} }
@ -461,9 +467,6 @@ public class TempleConfig {
vectorK.insertKMatrix(modelParameter.getMatrixK()); vectorK.insertKMatrix(modelParameter.getMatrixK());
break; break;
case Classifier.DNN: case Classifier.DNN:
// ModelParameter modelParameter1 = new ModelParameter();
// modelParameter1.setDepthNerves(modelParameter.getDepthNerves());
// modelParameter1.setOutNerves(modelParameter.getOutNerves());
nerveManager.insertModelParameter(modelParameter); nerveManager.insertModelParameter(modelParameter);
normalization = new Normalization(); normalization = new Normalization();
normalization.setAvg(modelParameter.getDnnAvg()); normalization.setAvg(modelParameter.getDnnAvg());

@ -46,7 +46,7 @@ public class OutNerve extends Nerve {
if (E.containsKey(getId())) { if (E.containsKey(getId())) {
this.E = E.get(getId()); this.E = E.get(getId());
} else { } else {
this.E = 0; this.E = -1;
} }
if (isShowLog) { if (isShowLog) {
System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId()); System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId());

@ -127,44 +127,44 @@ public class HelloWorld {
templeConfig.setClassifier(Classifier.DNN); templeConfig.setClassifier(Classifier.DNN);
templeConfig.isShowLog(true); templeConfig.isShowLog(true);
templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4);
// ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA2, ModelParameter.class); ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA3, ModelParameter.class);
// templeConfig.insertModel(modelParameter2); templeConfig.insertModel(modelParameter2);
Operation operation = new Operation(templeConfig); Operation operation = new Operation(templeConfig);
//a b c d 物品 e是背景 //a b c d 物品 e是背景
// 一阶段 // 一阶段
for (int j = 0; j < 1; j++) { // for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) {//一阶段 // for (int i = 1; i < 1900; i++) {//一阶段
System.out.println("study1===================" + i); // System.out.println("study1===================" + i);
//读取本地URL地址图片,并转化成矩阵 // //读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); // Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg"); // Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg"); // Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg"); // Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
//将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习 // //将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
//第二次学习的时候,第三个参数必须是 true // //第二次学习的时候,第三个参数必须是 true
operation.learning(a, 1, false); // operation.learning(a, 1, false);
operation.learning(b, 2, false); // operation.learning(b, 2, false);
operation.learning(c, 3, false); // operation.learning(c, 3, false);
operation.learning(d, 4, false); // operation.learning(d, 4, false);
} // }
} // }
//二阶段 //二阶段
for (int i = 1; i < 1900; i++) { // for (int i = 1; i < 1900; i++) {
System.out.println("avg==" + i); // System.out.println("avg==" + i);
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); // Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg"); // Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg"); // Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg"); // Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
operation.normalization(a, templeConfig.getConvolutionNerveManager()); // operation.normalization(a, templeConfig.getConvolutionNerveManager());
operation.normalization(b, templeConfig.getConvolutionNerveManager()); // operation.normalization(b, templeConfig.getConvolutionNerveManager());
operation.normalization(c, templeConfig.getConvolutionNerveManager()); // operation.normalization(c, templeConfig.getConvolutionNerveManager());
operation.normalization(d, templeConfig.getConvolutionNerveManager()); // operation.normalization(d, templeConfig.getConvolutionNerveManager());
} // }
templeConfig.getNormalization().avg(); // templeConfig.getNormalization().avg();
for (int j = 0; j < 1; j++) { for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) { for (int i = 1; i < 1900; i++) {
System.out.println("study2==================" + i); System.out.println("j==" + j + ",study2==================" + i);
//读取本地URL地址图片,并转化成矩阵 //读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg"); Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save