增加TANH 激活函数,默认激活函数使用TANH

pull/7/head
Administrator 5 years ago
parent 3661bebdee
commit 49d3c71d39

@ -0,0 +1,18 @@
package org.wlld.function;
import org.wlld.i.ActiveFunction;
import org.wlld.tools.ArithUtil;
public class Tanh implements ActiveFunction {
@Override
public double function(double x) {
double son = ArithUtil.sub(Math.exp(x), Math.exp(-x));
double mother = ArithUtil.add(Math.exp(x), Math.exp(-x));
return ArithUtil.div(son, mother);
}
@Override
public double functionG(double out) {
return ArithUtil.sub(1, Math.pow(function(out), 2));
}
}

@ -6,6 +6,8 @@ import org.wlld.config.Classifier;
import org.wlld.config.StudyPattern;
import org.wlld.function.ReLu;
import org.wlld.function.Sigmod;
import org.wlld.function.Tanh;
import org.wlld.i.ActiveFunction;
import org.wlld.imageRecognition.border.*;
import org.wlld.imageRecognition.modelEntity.BoxList;
import org.wlld.imageRecognition.modelEntity.KBorder;
@ -49,7 +51,7 @@ public class TempleConfig {
private double avg = 0;//覆盖均值
private int sensoryNerveNub;//输入神经元个数
private boolean isShowLog = false;
private ActiveFunction activeFunction = new Tanh();
public boolean isAccurate() {
return isAccurate;
}
@ -58,6 +60,10 @@ public class TempleConfig {
isAccurate = accurate;
}
public void setActiveFunction(ActiveFunction activeFunction) {
this.activeFunction = activeFunction;
}
public double getAvg() {
return avg;
}
@ -247,7 +253,7 @@ public class TempleConfig {
private void initNerveManager(boolean initPower, int sensoryNerveNub
, int deep) throws Exception {
nerveManager = new NerveManager(sensoryNerveNub, 9,
classificationNub, deep, new Sigmod(), false, isAccurate);
classificationNub, deep, activeFunction, false, isAccurate);
nerveManager.init(initPower, false, isShowLog);
}
@ -461,9 +467,6 @@ public class TempleConfig {
vectorK.insertKMatrix(modelParameter.getMatrixK());
break;
case Classifier.DNN:
// ModelParameter modelParameter1 = new ModelParameter();
// modelParameter1.setDepthNerves(modelParameter.getDepthNerves());
// modelParameter1.setOutNerves(modelParameter.getOutNerves());
nerveManager.insertModelParameter(modelParameter);
normalization = new Normalization();
normalization.setAvg(modelParameter.getDnnAvg());

@ -46,7 +46,7 @@ public class OutNerve extends Nerve {
if (E.containsKey(getId())) {
this.E = E.get(getId());
} else {
this.E = 0;
this.E = -1;
}
if (isShowLog) {
System.out.println("E==" + this.E + ",out==" + out + ",nerveId==" + getId());

@ -127,44 +127,44 @@ public class HelloWorld {
templeConfig.setClassifier(Classifier.DNN);
templeConfig.isShowLog(true);
templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4);
// ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA2, ModelParameter.class);
// templeConfig.insertModel(modelParameter2);
ModelParameter modelParameter2 = JSON.parseObject(ModelData.DATA3, ModelParameter.class);
templeConfig.insertModel(modelParameter2);
Operation operation = new Operation(templeConfig);
//a b c d 物品 e是背景
// 一阶段
for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) {//一阶段
System.out.println("study1===================" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
//将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
//第二次学习的时候,第三个参数必须是 true
operation.learning(a, 1, false);
operation.learning(b, 2, false);
operation.learning(c, 3, false);
operation.learning(d, 4, false);
}
}
// for (int j = 0; j < 1; j++) {
// for (int i = 1; i < 1900; i++) {//一阶段
// System.out.println("study1===================" + i);
// //读取本地URL地址图片,并转化成矩阵
// Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
// Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
// Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
// Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
// //将图像矩阵和标注加入进行学习Accuracy_Pattern 模式 进行第二次学习
// //第二次学习的时候,第三个参数必须是 true
// operation.learning(a, 1, false);
// operation.learning(b, 2, false);
// operation.learning(c, 3, false);
// operation.learning(d, 4, false);
// }
// }
//二阶段
for (int i = 1; i < 1900; i++) {
System.out.println("avg==" + i);
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
operation.normalization(a, templeConfig.getConvolutionNerveManager());
operation.normalization(b, templeConfig.getConvolutionNerveManager());
operation.normalization(c, templeConfig.getConvolutionNerveManager());
operation.normalization(d, templeConfig.getConvolutionNerveManager());
}
templeConfig.getNormalization().avg();
// for (int i = 1; i < 1900; i++) {
// System.out.println("avg==" + i);
// Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
// Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");
// Matrix c = picture.getImageMatrixByLocal("D:\\share\\picture/c" + i + ".jpg");
// Matrix d = picture.getImageMatrixByLocal("D:\\share\\picture/d" + i + ".jpg");
// operation.normalization(a, templeConfig.getConvolutionNerveManager());
// operation.normalization(b, templeConfig.getConvolutionNerveManager());
// operation.normalization(c, templeConfig.getConvolutionNerveManager());
// operation.normalization(d, templeConfig.getConvolutionNerveManager());
// }
// templeConfig.getNormalization().avg();
for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) {
System.out.println("study2==================" + i);
System.out.println("j==" + j + ",study2==================" + i);
//读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save