From 625ecc2c018a66999f4aef110e75c430347a166f Mon Sep 17 00:00:00 2001 From: lidapeng Date: Mon, 23 Dec 2019 10:48:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BC=80=E5=A7=8B=E5=81=9A=E8=AF=AF=E5=B7=AE?= =?UTF-8?q?=E5=8F=8D=E5=90=91=E4=BC=A0=E6=92=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/App.java | 2 +- .../org/wlld/nerveCenter/NerveManager.java | 14 +++++-- .../org/wlld/nerveEntity/HiddenNerve.java | 9 +++-- src/main/java/org/wlld/nerveEntity/Nerve.java | 37 +++++++++++++++---- .../java/org/wlld/nerveEntity/OutNerve.java | 5 ++- .../org/wlld/nerveEntity/SensoryNerve.java | 4 +- 6 files changed, 52 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/wlld/App.java b/src/main/java/org/wlld/App.java index 55e5eff..e570d7f 100644 --- a/src/main/java/org/wlld/App.java +++ b/src/main/java/org/wlld/App.java @@ -19,7 +19,7 @@ public class App { nerveManager.init(); List sensoryNerves = nerveManager.getSensoryNerves(); for (int i = 0; i < sensoryNerves.size(); i++) { - sensoryNerves.get(i).postMessage(1, 2 + i); + sensoryNerves.get(i).postMessage(1, 2 + i, true); } } } diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java index ed2fde3..ae17c7d 100644 --- a/src/main/java/org/wlld/nerveCenter/NerveManager.java +++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java @@ -36,18 +36,22 @@ public class NerveManager { public void init() {//进行神经网络的初始化构建 initDepthNerve();//初始化深度隐层神经元 - List nerveList = depthNerves.get(0); + List nerveList = depthNerves.get(0);//第一层隐层神经元 + //最后一层隐层神经元啊 List lastNeveList = depthNerves.get(depthNerves.size() - 1); //初始化输出神经元 List outNevers = new ArrayList<>(); for (int i = 1; i < outNerveNub + 1; i++) { OutNerve outNerve = new OutNerve(i, hiddenNerverNub); + //输出层神经元连接最后一层隐层神经元 + outNerve.connectFathor(lastNeveList); outNevers.add(outNerve); } //最后一层隐层神经元 与输出神经元进行连接 for (Nerve nerve : lastNeveList) { nerve.connect(outNevers); } + //初始化感知神经元 for (int i = 1; i < sensoryNerveNub + 1; i++) { SensoryNerve sensoryNerve = new SensoryNerve(i, 0); @@ -56,7 +60,6 @@ public class NerveManager { sensoryNerves.add(sensoryNerve); } - } private void initDepthNerve() {//初始化隐层神经元1 @@ -79,11 +82,14 @@ public class NerveManager { private void initHiddenNerve() {//初始化隐层神经元2 for (int i = 0; i < hiddenDepth - 1; i++) {//遍历深度 - List hiddenNerveList = depthNerves.get(i); - List nextHiddenNerveList = depthNerves.get(i + 1); + List hiddenNerveList = depthNerves.get(i);//当前遍历隐层神经元 + List nextHiddenNerveList = depthNerves.get(i + 1);//当前遍历的下一层神经元 for (Nerve hiddenNerve : hiddenNerveList) { hiddenNerve.connect(nextHiddenNerveList); } + for (Nerve nextHiddenNerve : nextHiddenNerveList) { + nextHiddenNerve.connectFathor(hiddenNerveList); + } } } } diff --git a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java index abc2e9a..5851179 100644 --- a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java +++ b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java @@ -19,15 +19,18 @@ public class HiddenNerve extends Nerve { } @Override - public void input(long eventId, double parameter) throws Exception {//接收上一层的输入 + public void input(long eventId, double parameter, boolean isStudy) throws Exception {//接收上一层的输入 logger.debug("name:{},myId:{},depth:{},eventId:{},parameter:{}--getInput", name, getId(), depth, eventId, parameter); boolean allReady = insertParameter(eventId, parameter); if (allReady) {//参数齐了,开始计算 sigma - threshold logger.debug("depth:{},myID:{}--startCalculation", depth, getId()); double sigma = calculation(eventId); - double out = activeFunction.sigmoid(sigma); + double out = activeFunction.sigmoid(sigma);//激活函数输出数值 + if (isStudy) { + outNub = out; + } logger.debug("depth:{},myID:{},outPut:{}", depth, getId(), out); - sendMessage(eventId, out); + sendMessage(eventId, out, isStudy); } // sendMessage(); } diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index 95152b1..0bcc962 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -13,7 +13,8 @@ import java.util.*; * @date 9:36 上午 2019/12/21 */ public abstract class Nerve { - private List axon = new ArrayList<>();//轴突下一层的连接神经元 + private List son = new ArrayList<>();//轴突下一层的连接神经元 + private List fathor = new ArrayList<>();//树突上一层的连接神经元 private Map dendrites = new HashMap<>();//上一层权重 private int id;//同级神经元编号,注意在同层编号中ID应有唯一性 protected int upNub;//上一层神经元数量 @@ -21,7 +22,8 @@ public abstract class Nerve { static final Logger logger = LogManager.getLogger(Nerve.class); private double threshold;//此神经元的阈值 protected ActiveFunction activeFunction = new ActiveFunction(); - protected String name; + protected String name;//该神经元所属类型 + protected double outNub;//输出数值(ps:只有训练模式的时候才可保存输出过的数值) protected Nerve(int id, int upNub, String name) {//该神经元在同层神经元中的编号 this.id = id; @@ -30,17 +32,32 @@ public abstract class Nerve { initPower();//生成随机权重 } - public void sendMessage(long enevtId, double parameter) throws Exception { - if (axon.size() > 0) { - for (Nerve nerve : axon) { - nerve.input(enevtId, parameter); + public void sendMessage(long enevtId, double parameter, boolean isStudy) throws Exception { + if (son.size() > 0) { + for (Nerve nerve : son) { + nerve.input(enevtId, parameter, isStudy); } } else { throw new Exception("this layer is lastIndex"); } } - protected void input(long eventId, double parameter) throws Exception {//输入 + public void backSendMessage(double parameter) throws Exception {//反向传播 + if (fathor.size() > 0) { + for (Nerve nerve : fathor) { + nerve.backGetMessage(parameter); + } + } else { + throw new Exception("this layer is firstIndex"); + } + } + + protected void input(long eventId, double parameter, boolean isStudy) throws Exception {//输入 + + } + + private void backGetMessage(double parameter) {//反向传播 + } protected boolean insertParameter(long eventId, double parameter) {//添加参数 @@ -93,6 +110,10 @@ public abstract class Nerve { public void connect(List nerveList) { - axon.addAll(nerveList); + son.addAll(nerveList);//连接下一层 + } + + public void connectFathor(List nerveList) { + fathor.addAll(nerveList);//连接上一层 } } diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index c2bf3f7..9b6e1df 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -16,12 +16,15 @@ public class OutNerve extends Nerve { } @Override - public void input(long eventId, double parameter) { + public void input(long eventId, double parameter, boolean isStudy) { logger.debug("Nerve:{},eventId:{},parameter:{}--getInput", name, eventId, parameter); boolean allReady = insertParameter(eventId, parameter); if (allReady) {//参数齐了,开始计算 sigma - threshold double sigma = calculation(eventId); double out = activeFunction.sigmoid(sigma); + if (isStudy) { + outNub = out; + } logger.debug("myId:{},outPut:{}------END", getId(), out); } } diff --git a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java index 18aa6c9..32bdc51 100644 --- a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java +++ b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java @@ -13,8 +13,8 @@ public class SensoryNerve extends Nerve { super(id, upNub, "SensoryNerve"); } - public void postMessage(long eventId, double parameter) throws Exception {//感知神经元输出 - sendMessage(eventId, parameter); + public void postMessage(long eventId, double parameter, boolean isStudy) throws Exception {//感知神经元输出 + sendMessage(eventId, parameter, isStudy); } @Override