开始做误差反向传播

pull/1/head
lidapeng 5 years ago
parent d3c65f4a97
commit 625ecc2c01

@ -19,7 +19,7 @@ public class App {
nerveManager.init(); nerveManager.init();
List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves(); List<SensoryNerve> sensoryNerves = nerveManager.getSensoryNerves();
for (int i = 0; i < sensoryNerves.size(); i++) { for (int i = 0; i < sensoryNerves.size(); i++) {
sensoryNerves.get(i).postMessage(1, 2 + i); sensoryNerves.get(i).postMessage(1, 2 + i, true);
} }
} }
} }

@ -36,18 +36,22 @@ public class NerveManager {
public void init() {//进行神经网络的初始化构建 public void init() {//进行神经网络的初始化构建
initDepthNerve();//初始化深度隐层神经元 initDepthNerve();//初始化深度隐层神经元
List<Nerve> nerveList = depthNerves.get(0); List<Nerve> nerveList = depthNerves.get(0);//第一层隐层神经元
//最后一层隐层神经元啊
List<Nerve> lastNeveList = depthNerves.get(depthNerves.size() - 1); List<Nerve> lastNeveList = depthNerves.get(depthNerves.size() - 1);
//初始化输出神经元 //初始化输出神经元
List<Nerve> outNevers = new ArrayList<>(); List<Nerve> outNevers = new ArrayList<>();
for (int i = 1; i < outNerveNub + 1; i++) { for (int i = 1; i < outNerveNub + 1; i++) {
OutNerve outNerve = new OutNerve(i, hiddenNerverNub); OutNerve outNerve = new OutNerve(i, hiddenNerverNub);
//输出层神经元连接最后一层隐层神经元
outNerve.connectFathor(lastNeveList);
outNevers.add(outNerve); outNevers.add(outNerve);
} }
//最后一层隐层神经元 与输出神经元进行连接 //最后一层隐层神经元 与输出神经元进行连接
for (Nerve nerve : lastNeveList) { for (Nerve nerve : lastNeveList) {
nerve.connect(outNevers); nerve.connect(outNevers);
} }
//初始化感知神经元 //初始化感知神经元
for (int i = 1; i < sensoryNerveNub + 1; i++) { for (int i = 1; i < sensoryNerveNub + 1; i++) {
SensoryNerve sensoryNerve = new SensoryNerve(i, 0); SensoryNerve sensoryNerve = new SensoryNerve(i, 0);
@ -56,7 +60,6 @@ public class NerveManager {
sensoryNerves.add(sensoryNerve); sensoryNerves.add(sensoryNerve);
} }
} }
private void initDepthNerve() {//初始化隐层神经元1 private void initDepthNerve() {//初始化隐层神经元1
@ -79,11 +82,14 @@ public class NerveManager {
private void initHiddenNerve() {//初始化隐层神经元2 private void initHiddenNerve() {//初始化隐层神经元2
for (int i = 0; i < hiddenDepth - 1; i++) {//遍历深度 for (int i = 0; i < hiddenDepth - 1; i++) {//遍历深度
List<Nerve> hiddenNerveList = depthNerves.get(i); List<Nerve> hiddenNerveList = depthNerves.get(i);//当前遍历隐层神经元
List<Nerve> nextHiddenNerveList = depthNerves.get(i + 1); List<Nerve> nextHiddenNerveList = depthNerves.get(i + 1);//当前遍历的下一层神经元
for (Nerve hiddenNerve : hiddenNerveList) { for (Nerve hiddenNerve : hiddenNerveList) {
hiddenNerve.connect(nextHiddenNerveList); hiddenNerve.connect(nextHiddenNerveList);
} }
for (Nerve nextHiddenNerve : nextHiddenNerveList) {
nextHiddenNerve.connectFathor(hiddenNerveList);
}
} }
} }
} }

@ -19,15 +19,18 @@ public class HiddenNerve extends Nerve {
} }
@Override @Override
public void input(long eventId, double parameter) throws Exception {//接收上一层的输入 public void input(long eventId, double parameter, boolean isStudy) throws Exception {//接收上一层的输入
logger.debug("name:{},myId:{},depth:{},eventId:{},parameter:{}--getInput", name, getId(), depth, eventId, parameter); logger.debug("name:{},myId:{},depth:{},eventId:{},parameter:{}--getInput", name, getId(), depth, eventId, parameter);
boolean allReady = insertParameter(eventId, parameter); boolean allReady = insertParameter(eventId, parameter);
if (allReady) {//参数齐了,开始计算 sigma - threshold if (allReady) {//参数齐了,开始计算 sigma - threshold
logger.debug("depth:{},myID:{}--startCalculation", depth, getId()); logger.debug("depth:{},myID:{}--startCalculation", depth, getId());
double sigma = calculation(eventId); double sigma = calculation(eventId);
double out = activeFunction.sigmoid(sigma); double out = activeFunction.sigmoid(sigma);//激活函数输出数值
if (isStudy) {
outNub = out;
}
logger.debug("depth:{},myID:{},outPut:{}", depth, getId(), out); logger.debug("depth:{},myID:{},outPut:{}", depth, getId(), out);
sendMessage(eventId, out); sendMessage(eventId, out, isStudy);
} }
// sendMessage(); // sendMessage();
} }

@ -13,7 +13,8 @@ import java.util.*;
* @date 9:36 2019/12/21 * @date 9:36 2019/12/21
*/ */
public abstract class Nerve { public abstract class Nerve {
private List<Nerve> axon = new ArrayList<>();//轴突下一层的连接神经元 private List<Nerve> son = new ArrayList<>();//轴突下一层的连接神经元
private List<Nerve> fathor = new ArrayList<>();//树突上一层的连接神经元
private Map<Integer, Double> dendrites = new HashMap<>();//上一层权重 private Map<Integer, Double> dendrites = new HashMap<>();//上一层权重
private int id;//同级神经元编号,注意在同层编号中ID应有唯一性 private int id;//同级神经元编号,注意在同层编号中ID应有唯一性
protected int upNub;//上一层神经元数量 protected int upNub;//上一层神经元数量
@ -21,7 +22,8 @@ public abstract class Nerve {
static final Logger logger = LogManager.getLogger(Nerve.class); static final Logger logger = LogManager.getLogger(Nerve.class);
private double threshold;//此神经元的阈值 private double threshold;//此神经元的阈值
protected ActiveFunction activeFunction = new ActiveFunction(); protected ActiveFunction activeFunction = new ActiveFunction();
protected String name; protected String name;//该神经元所属类型
protected double outNub;//输出数值ps:只有训练模式的时候才可保存输出过的数值)
protected Nerve(int id, int upNub, String name) {//该神经元在同层神经元中的编号 protected Nerve(int id, int upNub, String name) {//该神经元在同层神经元中的编号
this.id = id; this.id = id;
@ -30,17 +32,32 @@ public abstract class Nerve {
initPower();//生成随机权重 initPower();//生成随机权重
} }
public void sendMessage(long enevtId, double parameter) throws Exception { public void sendMessage(long enevtId, double parameter, boolean isStudy) throws Exception {
if (axon.size() > 0) { if (son.size() > 0) {
for (Nerve nerve : axon) { for (Nerve nerve : son) {
nerve.input(enevtId, parameter); nerve.input(enevtId, parameter, isStudy);
} }
} else { } else {
throw new Exception("this layer is lastIndex"); throw new Exception("this layer is lastIndex");
} }
} }
protected void input(long eventId, double parameter) throws Exception {//输入 public void backSendMessage(double parameter) throws Exception {//反向传播
if (fathor.size() > 0) {
for (Nerve nerve : fathor) {
nerve.backGetMessage(parameter);
}
} else {
throw new Exception("this layer is firstIndex");
}
}
protected void input(long eventId, double parameter, boolean isStudy) throws Exception {//输入
}
private void backGetMessage(double parameter) {//反向传播
} }
protected boolean insertParameter(long eventId, double parameter) {//添加参数 protected boolean insertParameter(long eventId, double parameter) {//添加参数
@ -93,6 +110,10 @@ public abstract class Nerve {
public void connect(List<Nerve> nerveList) { public void connect(List<Nerve> nerveList) {
axon.addAll(nerveList); son.addAll(nerveList);//连接下一层
}
public void connectFathor(List<Nerve> nerveList) {
fathor.addAll(nerveList);//连接上一层
} }
} }

@ -16,12 +16,15 @@ public class OutNerve extends Nerve {
} }
@Override @Override
public void input(long eventId, double parameter) { public void input(long eventId, double parameter, boolean isStudy) {
logger.debug("Nerve:{},eventId:{},parameter:{}--getInput", name, eventId, parameter); logger.debug("Nerve:{},eventId:{},parameter:{}--getInput", name, eventId, parameter);
boolean allReady = insertParameter(eventId, parameter); boolean allReady = insertParameter(eventId, parameter);
if (allReady) {//参数齐了,开始计算 sigma - threshold if (allReady) {//参数齐了,开始计算 sigma - threshold
double sigma = calculation(eventId); double sigma = calculation(eventId);
double out = activeFunction.sigmoid(sigma); double out = activeFunction.sigmoid(sigma);
if (isStudy) {
outNub = out;
}
logger.debug("myId:{},outPut:{}------END", getId(), out); logger.debug("myId:{},outPut:{}------END", getId(), out);
} }
} }

@ -13,8 +13,8 @@ public class SensoryNerve extends Nerve {
super(id, upNub, "SensoryNerve"); super(id, upNub, "SensoryNerve");
} }
public void postMessage(long eventId, double parameter) throws Exception {//感知神经元输出 public void postMessage(long eventId, double parameter, boolean isStudy) throws Exception {//感知神经元输出
sendMessage(eventId, parameter); sendMessage(eventId, parameter, isStudy);
} }
@Override @Override

Loading…
Cancel
Save