修改梯度下降BUG

pull/1/head
lidapeng 5 years ago
parent 703fad1189
commit c60101d8ff

@ -26,17 +26,18 @@ public class App {
nerveManager.setOutBack(new Test());//添加判断回调输出类
List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves();
Map<Integer, Double> E1 = new HashMap<>();
E1.put(1, 1.0);
E1.put(1, 0.0);
Map<Integer, Double> E2 = new HashMap<>();
E2.put(1, 0.0);
E2.put(1, 1.0);
Random random = new Random();
List<List<Double>> testList = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
double nub = ArithUtil.mul(ArithUtil.div(1, 1000), i);
List<Double> dm = new ArrayList<>();
dm.add(ArithUtil.add(1.0, random.nextDouble()));
dm.add(ArithUtil.add(0.8, random.nextDouble()));
dm.add(ArithUtil.add(-1.0, -random.nextDouble()));
dm.add(ArithUtil.add(-0.8, -random.nextDouble()));
dm.add(ArithUtil.add(0.5, nub));
dm.add(ArithUtil.add(0.5, nub));
dm.add(ArithUtil.add(-0.5, -nub));
dm.add(ArithUtil.add(-0.5, -nub));
testList.add(dm);
}
for (int i = 0; i < 1000; i++) {
@ -47,14 +48,16 @@ public class App {
sensoryNerves.get(1).postMessage(1, ds.get(3), true, E2);
}
Nerve hiddenNerve = nerveManager.getDepthNerves().get(0).get(0);
Nerve outNerver = nerveManager.getOutNevers().get(0);
double hiddenTh = hiddenNerve.getThreshold();//隐层阈值
double outTh = outNerver.getThreshold();//输出阈值
System.out.println("hiddenTh==" + hiddenTh + ",outTh==" + outTh);
sensoryNerves.get(0).postMessage(1, 1.5, false, E1);
sensoryNerves.get(1).postMessage(1, 1.2, false, E1);
sensoryNerves.get(0).postMessage(1, -1.5, false, E2);
sensoryNerves.get(1).postMessage(1, -1.2, false, E2);
System.out.println("outTh==" + outTh);
for (int i = 2; i < 22; i++) {
double nub = ArithUtil.mul(ArithUtil.div(1, 1000), i);
sensoryNerves.get(0).postMessage(1, ArithUtil.add(0.5, nub), false, E1);
sensoryNerves.get(1).postMessage(1, ArithUtil.add(0.5, nub), false, E1);
sensoryNerves.get(0).postMessage(1, ArithUtil.add(-0.5, -nub), false, E2);
sensoryNerves.get(1).postMessage(1, ArithUtil.add(-0.5, -nub), false, E2);
System.out.println("====================================");
}
}
}

@ -106,8 +106,8 @@ public abstract class Nerve {
double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重
double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入
double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值
double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元
w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重
double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元
wg.put(key, dm);//保存上一层权重与梯度的积
dendrites.put(key, w);//保存修正结果
}

Loading…
Cancel
Save