From 7c4fb8f2d7d583cbfe258042faca22093e223fd9 Mon Sep 17 00:00:00 2001 From: lidapeng Date: Mon, 30 Dec 2019 22:28:31 +0800 Subject: [PATCH] =?UTF-8?q?=E7=B4=AF=E8=AE=A1BP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/App.java | 47 ++++++++++++++----- .../org/wlld/nerveEntity/HiddenNerve.java | 1 + src/main/java/org/wlld/nerveEntity/Nerve.java | 8 ++-- .../java/org/wlld/nerveEntity/OutNerve.java | 9 +++- .../org/wlld/nerveEntity/SensoryNerve.java | 1 + 5 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/main/java/org/wlld/App.java b/src/main/java/org/wlld/App.java index bdf23eb..9808dbc 100644 --- a/src/main/java/org/wlld/App.java +++ b/src/main/java/org/wlld/App.java @@ -1,11 +1,11 @@ package org.wlld; import org.wlld.nerveCenter.NerveManager; +import org.wlld.nerveEntity.Nerve; import org.wlld.nerveEntity.SensoryNerve; +import org.wlld.tools.ArithUtil; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; /** * 测试入口类! @@ -19,21 +19,42 @@ public class App { //构建一个神经网络管理器,参数:(感知神经元个数,隐层神经元个数,输出神经元个数,输出神经元深度) //一个神经网络管理管理一个神经网络学习内容, NerveManager nerveManager = - new NerveManager(2, 3, 1, 2); + new NerveManager(2, 4, 1, 4); //开始构建神经网络,参数为是否初始化权重及阈值,若 + nerveManager.setStudyPoint(0.1);//设置学习率(取值范围是0-1开区间),若不设置默认为0.1 nerveManager.init(true); - nerveManager.setStudyPoint(0.2);//设置学习率(取值范围是0-1开区间),若不设置默认为0.1 nerveManager.setOutBack(new Test());//添加判断回调输出类 List sensoryNerves = nerveManager.getSensoryNerves(); - Map E = new HashMap<>(); - for (int i = 0; i < nerveManager.getOutNerveNub(); i++) { - E.put(i + 1, 1.0); + Map E1 = new HashMap<>(); + E1.put(1, 1.0); + Map E2 = new HashMap<>(); + E2.put(1, 0.0); + Random random = new Random(); + List> testList = new ArrayList<>(); + for (int i = 0; i < 1000; i++) { + List dm = new ArrayList<>(); + dm.add(1.0); + dm.add(0.8); + dm.add(0.03); + dm.add(0.06); + testList.add(dm); } - for (int i = 0; i < sensoryNerves.size(); i++) { - sensoryNerves.get(i).postMessage(1, 2 + i, true, E); - } - for (int i = 0; i < sensoryNerves.size(); i++) { - sensoryNerves.get(i).postMessage(1, 2 + i, false, E); + for (int i = 0; i < 1000; i++) { + List ds = testList.get(i); + sensoryNerves.get(0).postMessage(1, ds.get(0), true, E1); + sensoryNerves.get(1).postMessage(1, ds.get(1), true, E1); + sensoryNerves.get(0).postMessage(1, ds.get(2), true, E2); + sensoryNerves.get(1).postMessage(1, ds.get(3), true, E2); } + + Nerve hiddenNerve = nerveManager.getDepthNerves().get(0).get(0); + Nerve outNerver = nerveManager.getOutNevers().get(0); + double hiddenTh = hiddenNerve.getThreshold();//隐层阈值 + double outTh = outNerver.getThreshold();//输出阈值 + System.out.println("hiddenTh==" + hiddenTh + ",outTh==" + outTh); + sensoryNerves.get(0).postMessage(1, 1.0, false, E1); + sensoryNerves.get(1).postMessage(1, 0.8, false, E1); + sensoryNerves.get(0).postMessage(1, 0.03, false, E2); + sensoryNerves.get(1).postMessage(1, 0.06, false, E2); } } diff --git a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java index 890643c..fc94c06 100644 --- a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java +++ b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java @@ -27,6 +27,7 @@ public class HiddenNerve extends Nerve { if (isStudy) { outNub = out; } else { + //System.out.println("sigma:" + sigma); destoryParameter(eventId); } // logger.debug("depth:{},myID:{},outPut:{}", depth, getId(), out); diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index 0fd6e28..70cceb3 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -93,7 +93,7 @@ public abstract class Nerve { protected void updatePower(long eventId) throws Exception {//修改阈值 double h = ArithUtil.mul(gradient, studyPoint);//梯度下降 - threshold = ArithUtil.sub(threshold, h);//更新阈值 + threshold = ArithUtil.add(threshold, -h);//更新阈值 updateW(h, eventId); sigmaW = 0;//求和结果归零 backSendMessage(eventId); @@ -106,8 +106,8 @@ public abstract class Nerve { double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 - w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 + w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 wg.put(key, dm);//保存上一层权重与梯度的积 dendrites.put(key, w);//保存修正结果 } @@ -140,11 +140,13 @@ public abstract class Nerve { for (int i = 0; i < featuresList.size(); i++) { double value = featuresList.get(i); double w = dendrites.get(i + 1); + //System.out.println("w==" + w + ",value==" + value); sigma = ArithUtil.add(ArithUtil.mul(w, value), sigma); //logger.debug("name:{},eventId:{},id:{},myId:{},w:{},value:{}", name, eventId, i + 1, id, w, value); } + //System.out.println("结束===========" + sigma); //logger.debug("当前神经元线性变化已经完成,name:{},id:{}", name, getId()); - return ArithUtil.sub(sigma, threshold); + return ArithUtil.add(sigma, threshold); } private void initPower() {//初始化权重及阈值 diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index 5cc7d96..5e7639b 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -13,6 +13,8 @@ import java.util.Map; public class OutNerve extends Nerve { // static final Logger logger = LogManager.getLogger(OutNerve.class); private OutBack outBack; + private long trainNub = 0;//训练次数 + private double allE;//训练累计EK public OutNerve(int id, int upNub, int downNub, double studyPoint, boolean init) { super(id, upNub, "OutNerve", downNub, studyPoint, init); @@ -31,12 +33,14 @@ public class OutNerve extends Nerve { double out = activeFunction.sigmoid(sigma); // logger.debug("myId:{},outPut:{}------END", getId(), out); if (isStudy) {//输出结果并进行BP调整权重及阈值 + trainNub++;//训练次数增加 outNub = out; this.E = E.get(getId()); gradient = outGradient();//当前梯度变化 //调整权重 修改阈值 并进行反向传播 updatePower(eventId); } else {//获取最后输出 + //System.out.println("当前阈值" + threshold); destoryParameter(eventId); if (outBack != null) { outBack.getBack(out, getId(), eventId); @@ -50,6 +54,9 @@ public class OutNerve extends Nerve { private double outGradient() {//生成输出层神经元梯度变化 //上层神经元输入值 * 当前神经元梯度*学习率 =该上层输入的神经元权重变化 //当前梯度神经元梯度变化 *学习旅 * -1 = 当前神经元阈值变化 - return ArithUtil.mul(activeFunction.sigmoidG(outNub), ArithUtil.sub(E, outNub)); + //ArithUtil.sub(E, outNub) 求这个的累计平均值 + allE = ArithUtil.add(Math.abs(ArithUtil.sub(E, outNub)), allE); + double avg = ArithUtil.div(allE, trainNub); + return ArithUtil.mul(activeFunction.sigmoidG(outNub), avg); } } diff --git a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java index 49b260c..3104e5b 100644 --- a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java +++ b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java @@ -15,6 +15,7 @@ public class SensoryNerve extends Nerve { } public void postMessage(long eventId, double parameter, boolean isStudy, Map E) throws Exception {//感知神经元输出 + sendMessage(eventId, parameter, isStudy, E); }