增加TANH 激活函数,默认激活函数使用TANH

pull/10/head
Administrator 5 years ago
parent 73ba62b42a
commit ec7dc5fd5c

@ -296,7 +296,9 @@ public class Operation {//进行计算
private void lvq(int tagging, Matrix myMatrix) throws Exception {//LVQ学习 private void lvq(int tagging, Matrix myMatrix) throws Exception {//LVQ学习
LVQ lvq = templeConfig.getLvq(); LVQ lvq = templeConfig.getLvq();
Matrix vector = MatrixOperation.matrixToVector(myMatrix, true); Matrix vector = MatrixOperation.matrixToVector(myMatrix, true);
System.out.println(vector.getString()); if (templeConfig.isShowLog()) {
System.out.println(vector.getString());
}
MatrixBody matrixBody = new MatrixBody(); MatrixBody matrixBody = new MatrixBody();
matrixBody.setMatrix(vector); matrixBody.setMatrix(vector);
matrixBody.setId(tagging); matrixBody.setId(tagging);

@ -262,7 +262,7 @@ public class TempleConfig {
private void initNerveManager(boolean initPower, int sensoryNerveNub private void initNerveManager(boolean initPower, int sensoryNerveNub
, int deep, double studyPoint) throws Exception { , int deep, double studyPoint) throws Exception {
nerveManager = new NerveManager(sensoryNerveNub, 9, nerveManager = new NerveManager(sensoryNerveNub, 6,
classificationNub, deep, activeFunction, false, isAccurate, studyPoint); classificationNub, deep, activeFunction, false, isAccurate, studyPoint);
nerveManager.init(initPower, false, isShowLog); nerveManager.init(initPower, false, isShowLog);
} }

@ -16,7 +16,7 @@ public class LVQ {
private int typeNub;//原型聚类个数,即分类个数(需要模型返回) private int typeNub;//原型聚类个数,即分类个数(需要模型返回)
private MatrixBody[] model;//原型向量(需要模型返回) private MatrixBody[] model;//原型向量(需要模型返回)
private List<MatrixBody> matrixList = new ArrayList<>(); private List<MatrixBody> matrixList = new ArrayList<>();
private double studyPoint = 0.1;//量化学习率 private double studyPoint = 0.0001;//量化学习率
private int length;//向量长度(需要返回) private int length;//向量长度(需要返回)
private boolean isReady = false; private boolean isReady = false;
private int lvqNub; private int lvqNub;
@ -86,8 +86,6 @@ public class LVQ {
long type = matrixBody.getId();//类别 long type = matrixBody.getId();//类别
double distEnd = 0; double distEnd = 0;
int id = 0; int id = 0;
double dis0 = 0;
double dis1 = 1;
for (int i = 0; i < typeNub; i++) { for (int i = 0; i < typeNub; i++) {
MatrixBody modelBody = model[i]; MatrixBody modelBody = model[i];
Matrix modelMatrix = modelBody.getMatrix(); Matrix modelMatrix = modelBody.getMatrix();
@ -97,16 +95,10 @@ public class LVQ {
id = modelBody.getId(); id = modelBody.getId();
distEnd = dist; distEnd = dist;
} }
if (i == 0) {
dis0 = dist;
} else {
dis1 = dist;
}
} }
MatrixBody modelBody = model[id]; MatrixBody modelBody = model[id];
Matrix modelMatrix = modelBody.getMatrix(); Matrix modelMatrix = modelBody.getMatrix();
boolean isRight = id == type; boolean isRight = id == type;
System.out.println("type==" + type + ",dis0==" + dis0 + ",dis1==" + dis1);
Matrix matrix1 = op(matrix, modelMatrix, isRight); Matrix matrix1 = op(matrix, modelMatrix, isRight);
modelBody.setMatrix(matrix1); modelBody.setMatrix(matrix1);
} }

@ -223,11 +223,13 @@ public abstract class Nerve {
private void updateW(double h, long eventId) {//h是学习率 * 当前g梯度 private void updateW(double h, long eventId) {//h是学习率 * 当前g梯度
List<Double> list = features.get(eventId); List<Double> list = features.get(eventId);
double stop = ArithUtil.sub(1, ArithUtil.div(ArithUtil.mul(studyPoint, 0.015), dendrites.size()));
for (Map.Entry<Integer, Double> entry : dendrites.entrySet()) { for (Map.Entry<Integer, Double> entry : dendrites.entrySet()) {
int key = entry.getKey();//上层隐层神经元的编号 int key = entry.getKey();//上层隐层神经元的编号
double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重
double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入
double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值
w = ArithUtil.mul(w, stop);
w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重
double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元
// System.out.println("allG==" + allG + ",dm==" + dm); // System.out.println("allG==" + allG + ",dm==" + dm);

@ -21,7 +21,7 @@ public class FoodTest {
public static void food() throws Exception { public static void food() throws Exception {
Picture picture = new Picture(); Picture picture = new Picture();
TempleConfig templeConfig = new TempleConfig(false, true); TempleConfig templeConfig = new TempleConfig(false, false);
templeConfig.setClassifier(Classifier.DNN); templeConfig.setClassifier(Classifier.DNN);
templeConfig.isShowLog(true); templeConfig.isShowLog(true);
templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4); templeConfig.init(StudyPattern.Accuracy_Pattern, true, 640, 640, 4);
@ -60,7 +60,7 @@ public class FoodTest {
// } // }
// templeConfig.getNormalization().avg(); // templeConfig.getNormalization().avg();
for (int j = 0; j < 1; j++) { for (int j = 0; j < 1; j++) {
for (int i = 1; i < 1900; i++) { for (int i = 1; i < 1500; i++) {
System.out.println("j==" + j + ",study2==================" + i); System.out.println("j==" + j + ",study2==================" + i);
//读取本地URL地址图片,并转化成矩阵 //读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
@ -88,7 +88,7 @@ public class FoodTest {
// Operation operation2 = new Operation(templeConfig2); // Operation operation2 = new Operation(templeConfig2);
int wrong = 0; int wrong = 0;
int allNub = 0; int allNub = 0;
for (int i = 1900; i <= 1998; i++) { for (int i = 1500; i <= 1600; i++) {
//读取本地URL地址图片,并转化成矩阵 //读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");
Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg"); Matrix b = picture.getImageMatrixByLocal("D:\\share\\picture/b" + i + ".jpg");

@ -79,7 +79,7 @@ public class HelloWorld {
} }
templeConfig.getNormalization().avg(); templeConfig.getNormalization().avg();
//三阶段学习 //三阶段学习
for (int i = 1; i < 1900; i++) { for (int i = 1; i < 1000; i++) {
System.out.println("study2==================" + i); System.out.println("study2==================" + i);
//读取本地URL地址图片,并转化成矩阵 //读取本地URL地址图片,并转化成矩阵
Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg"); Matrix a = picture.getImageMatrixByLocal("D:\\share\\picture/a" + i + ".jpg");

Loading…
Cancel
Save