修改梯度下降BUG

pull/1/head
lidapeng 5 years ago
parent 703fad1189
commit c60101d8ff

@ -26,17 +26,18 @@ public class App {
nerveManager.setOutBack(new Test());//添加判断回调输出类 nerveManager.setOutBack(new Test());//添加判断回调输出类
List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves(); List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves();
Map<Integer, Double> E1 = new HashMap<>(); Map<Integer, Double> E1 = new HashMap<>();
E1.put(1, 1.0); E1.put(1, 0.0);
Map<Integer, Double> E2 = new HashMap<>(); Map<Integer, Double> E2 = new HashMap<>();
E2.put(1, 0.0); E2.put(1, 1.0);
Random random = new Random(); Random random = new Random();
List<List<Double>> testList = new ArrayList<>(); List<List<Double>> testList = new ArrayList<>();
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
double nub = ArithUtil.mul(ArithUtil.div(1, 1000), i);
List<Double> dm = new ArrayList<>(); List<Double> dm = new ArrayList<>();
dm.add(ArithUtil.add(1.0, random.nextDouble())); dm.add(ArithUtil.add(0.5, nub));
dm.add(ArithUtil.add(0.8, random.nextDouble())); dm.add(ArithUtil.add(0.5, nub));
dm.add(ArithUtil.add(-1.0, -random.nextDouble())); dm.add(ArithUtil.add(-0.5, -nub));
dm.add(ArithUtil.add(-0.8, -random.nextDouble())); dm.add(ArithUtil.add(-0.5, -nub));
testList.add(dm); testList.add(dm);
} }
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
@ -47,14 +48,16 @@ public class App {
sensoryNerves.get(1).postMessage(1, ds.get(3), true, E2); sensoryNerves.get(1).postMessage(1, ds.get(3), true, E2);
} }
Nerve hiddenNerve = nerveManager.getDepthNerves().get(0).get(0);
Nerve outNerver = nerveManager.getOutNevers().get(0); Nerve outNerver = nerveManager.getOutNevers().get(0);
double hiddenTh = hiddenNerve.getThreshold();//隐层阈值
double outTh = outNerver.getThreshold();//输出阈值 double outTh = outNerver.getThreshold();//输出阈值
System.out.println("hiddenTh==" + hiddenTh + ",outTh==" + outTh); System.out.println("outTh==" + outTh);
sensoryNerves.get(0).postMessage(1, 1.5, false, E1); for (int i = 2; i < 22; i++) {
sensoryNerves.get(1).postMessage(1, 1.2, false, E1); double nub = ArithUtil.mul(ArithUtil.div(1, 1000), i);
sensoryNerves.get(0).postMessage(1, -1.5, false, E2); sensoryNerves.get(0).postMessage(1, ArithUtil.add(0.5, nub), false, E1);
sensoryNerves.get(1).postMessage(1, -1.2, false, E2); sensoryNerves.get(1).postMessage(1, ArithUtil.add(0.5, nub), false, E1);
sensoryNerves.get(0).postMessage(1, ArithUtil.add(-0.5, -nub), false, E2);
sensoryNerves.get(1).postMessage(1, ArithUtil.add(-0.5, -nub), false, E2);
System.out.println("====================================");
}
} }
} }

@ -106,8 +106,8 @@ public abstract class Nerve {
double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重
double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入
double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值
double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元
w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重
double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元
wg.put(key, dm);//保存上一层权重与梯度的积 wg.put(key, dm);//保存上一层权重与梯度的积
dendrites.put(key, w);//保存修正结果 dendrites.put(key, w);//保存修正结果
} }

Loading…
Cancel
Save