diff --git a/src/main/java/org/wlld/App.java b/src/main/java/org/wlld/App.java index 9808dbc..3c0d8b5 100644 --- a/src/main/java/org/wlld/App.java +++ b/src/main/java/org/wlld/App.java @@ -19,7 +19,7 @@ public class App { //构建一个神经网络管理器,参数:(感知神经元个数,隐层神经元个数,输出神经元个数,输出神经元深度) //一个神经网络管理管理一个神经网络学习内容, NerveManager nerveManager = - new NerveManager(2, 4, 1, 4); + new NerveManager(2, 2, 1, 2); //开始构建神经网络,参数为是否初始化权重及阈值,若 nerveManager.setStudyPoint(0.1);//设置学习率(取值范围是0-1开区间),若不设置默认为0.1 nerveManager.init(true); @@ -33,10 +33,10 @@ public class App { List> testList = new ArrayList<>(); for (int i = 0; i < 1000; i++) { List dm = new ArrayList<>(); - dm.add(1.0); - dm.add(0.8); - dm.add(0.03); - dm.add(0.06); + dm.add(ArithUtil.add(1.0, random.nextDouble())); + dm.add(ArithUtil.add(0.8, random.nextDouble())); + dm.add(ArithUtil.add(-1.0, -random.nextDouble())); + dm.add(ArithUtil.add(-0.8, -random.nextDouble())); testList.add(dm); } for (int i = 0; i < 1000; i++) { @@ -52,9 +52,9 @@ public class App { double hiddenTh = hiddenNerve.getThreshold();//隐层阈值 double outTh = outNerver.getThreshold();//输出阈值 System.out.println("hiddenTh==" + hiddenTh + ",outTh==" + outTh); - sensoryNerves.get(0).postMessage(1, 1.0, false, E1); - sensoryNerves.get(1).postMessage(1, 0.8, false, E1); - sensoryNerves.get(0).postMessage(1, 0.03, false, E2); - sensoryNerves.get(1).postMessage(1, 0.06, false, E2); + sensoryNerves.get(0).postMessage(1, 1.5, false, E1); + sensoryNerves.get(1).postMessage(1, 1.2, false, E1); + sensoryNerves.get(0).postMessage(1, -1.5, false, E2); + sensoryNerves.get(1).postMessage(1, -1.2, false, E2); } } diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index 5e7639b..4b403be 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -55,8 +55,8 @@ public class OutNerve extends Nerve { //上层神经元输入值 * 当前神经元梯度*学习率 =该上层输入的神经元权重变化 //当前梯度神经元梯度变化 *学习旅 * -1 = 当前神经元阈值变化 //ArithUtil.sub(E, outNub) 求这个的累计平均值 - allE = ArithUtil.add(Math.abs(ArithUtil.sub(E, outNub)), allE); - double avg = ArithUtil.div(allE, trainNub); - return ArithUtil.mul(activeFunction.sigmoidG(outNub), avg); + //allE = ArithUtil.add(Math.abs(ArithUtil.sub(E, outNub)), allE); + // double avg = ArithUtil.div(allE, trainNub); + return ArithUtil.mul(activeFunction.sigmoidG(outNub), ArithUtil.sub(E, outNub)); } }