From eb92d8c8ef92781f06b6a01180de5fd4da0f5e37 Mon Sep 17 00:00:00 2001 From: lidapeng Date: Sat, 21 Mar 2020 21:41:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=89=93=E5=8D=B0=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/config/RZ.java | 12 ++++++++ .../wlld/imageRecognition/TempleConfig.java | 22 +++++++++++--- .../org/wlld/nerveCenter/NerveManager.java | 12 ++++++-- .../org/wlld/nerveEntity/HiddenNerve.java | 6 ++-- src/main/java/org/wlld/nerveEntity/Nerve.java | 29 ++++++++++++++++--- .../java/org/wlld/nerveEntity/OutNerve.java | 5 ++-- .../org/wlld/nerveEntity/SensoryNerve.java | 3 +- src/test/java/org/wlld/NerveDemo1.java | 7 +++-- 8 files changed, 78 insertions(+), 18 deletions(-) create mode 100644 src/main/java/org/wlld/config/RZ.java diff --git a/src/main/java/org/wlld/config/RZ.java b/src/main/java/org/wlld/config/RZ.java new file mode 100644 index 0000000..aedabad --- /dev/null +++ b/src/main/java/org/wlld/config/RZ.java @@ -0,0 +1,12 @@ +package org.wlld.config; + +/** + * @author lidapeng + * @description 正则化选项 + * @date 8:55 下午 2020/3/21 + */ +public class RZ { + public static final int NOT_RZ = 0; + public static final int L1 = 1; + public static final int L2 = 2; +} diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 1134e14..3128fd4 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -3,6 +3,7 @@ package org.wlld.imageRecognition; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; import org.wlld.config.Classifier; +import org.wlld.config.RZ; import org.wlld.config.StudyPattern; import org.wlld.function.ReLu; import org.wlld.function.Sigmod; @@ -54,6 +55,17 @@ public class TempleConfig { private ActiveFunction activeFunction = new Tanh(); private double studyPoint = 0; private double matrixWidth = 1;//期望矩阵间隔 + private int rzType = RZ.NOT_RZ;//正则化类型,默认不进行正则化 + private double lParam = 0;//正则参数 + private int hiddenNerveNub = 9;//隐层神经元个数 + + public void setRzType(int rzType) { + this.rzType = rzType; + } + + public void setlParam(double lParam) { + this.lParam = lParam; + } public void setMatrixWidth(double matrixWidth) { this.matrixWidth = matrixWidth; @@ -120,7 +132,7 @@ public class TempleConfig { return convolutionNerveManagerB; } - public void finishStudy() throws Exception {//结束 + public void finishStudy() throws Exception {//结束学习 switch (classifier) { case Classifier.LVQ: lvq.start(); @@ -267,8 +279,9 @@ public class TempleConfig { private void initNerveManager(boolean initPower, int sensoryNerveNub , int deep, double studyPoint) throws Exception { - nerveManager = new NerveManager(sensoryNerveNub, 6, - classificationNub, deep, activeFunction, false, isAccurate, studyPoint); + nerveManager = new NerveManager(sensoryNerveNub, hiddenNerveNub, + classificationNub, deep, activeFunction, + false, isAccurate, studyPoint, rzType, lParam); nerveManager.init(initPower, false, isShowLog); } @@ -320,7 +333,8 @@ public class TempleConfig { private NerveManager initNerveManager(Map matrixMap, boolean initPower, int deep) throws Exception { //初始化卷积神经网络 NerveManager convolutionNerveManager = new NerveManager(1, 1, - 1, deep - 1, new ReLu(), true, isAccurate, studyPoint); + 1, deep - 1, new ReLu(), + true, isAccurate, studyPoint, rzType, lParam); convolutionNerveManager.setMatrixMap(matrixMap);//给卷积网络管理器注入期望矩阵 convolutionNerveManager.init(initPower, true, isShowLog); return convolutionNerveManager; diff --git a/src/main/java/org/wlld/nerveCenter/NerveManager.java b/src/main/java/org/wlld/nerveCenter/NerveManager.java index 07eea17..0fbab61 100644 --- a/src/main/java/org/wlld/nerveCenter/NerveManager.java +++ b/src/main/java/org/wlld/nerveCenter/NerveManager.java @@ -31,6 +31,8 @@ public class NerveManager { private boolean isDynamic;//是否是动态神经网络 private List studyList = new ArrayList<>(); private boolean isAccurate;//是否保留精度 + private int rzType;//正则化类型,默认不进行正则化 + private double lParam;//正则参数 public List getStudyList() {//查看每一次的学习率 return studyList; @@ -194,11 +196,13 @@ public class NerveManager { * @param activeFunction 激活函数 * @param isDynamic 是否是动态神经元 * @param isAccurate 是否保留精度 + * @param rzType 正则函数 + * @param lParam 正则系数 * @throws Exception 如果参数错误则抛异常 */ public NerveManager(int sensoryNerveNub, int hiddenNerveNub, int outNerveNub , int hiddenDepth, ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate, - double studyPoint) throws Exception { + double studyPoint, int rzType, double lParam) throws Exception { if (sensoryNerveNub > 0 && hiddenNerveNub > 0 && outNerveNub > 0 && hiddenDepth > 0 && activeFunction != null) { this.hiddenNerveNub = hiddenNerveNub; this.sensoryNerveNub = sensoryNerveNub; @@ -207,6 +211,8 @@ public class NerveManager { this.activeFunction = activeFunction; this.isDynamic = isDynamic; this.isAccurate = isAccurate; + this.rzType = rzType; + this.lParam = lParam; if (studyPoint > 0 && studyPoint < 1) { this.studyPoint = studyPoint; } @@ -259,7 +265,7 @@ public class NerveManager { //初始化输出神经元 for (int i = 1; i < outNerveNub + 1; i++) { OutNerve outNerve = new OutNerve(i, hiddenNerveNub, 0, studyPoint, initPower, - activeFunction, isMatrix, isAccurate, isShowLog); + activeFunction, isMatrix, isAccurate, isShowLog, rzType, lParam); if (isMatrix) {//是卷积层神经网络 outNerve.setMatrixMap(matrixMap); } @@ -306,7 +312,7 @@ public class NerveManager { downNub = hiddenNerveNub; } HiddenNerve hiddenNerve = new HiddenNerve(j, i + 1, upNub, downNub, studyPoint, initPower, activeFunction, isMatrix - , isAccurate); + , isAccurate, rzType, lParam); hiddenNerveList.add(hiddenNerve); } depthNerves.add(hiddenNerveList); diff --git a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java index c6a1bf5..f82d3df 100644 --- a/src/main/java/org/wlld/nerveEntity/HiddenNerve.java +++ b/src/main/java/org/wlld/nerveEntity/HiddenNerve.java @@ -17,8 +17,10 @@ public class HiddenNerve extends Nerve { private int depth;//所处深度 public HiddenNerve(int id, int depth, int upNub, int downNub, double studyPoint, - boolean init, ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate) throws Exception {//隐层神经元 - super(id, upNub, "HiddenNerve", downNub, studyPoint, init, activeFunction, isDynamic, isAccurate); + boolean init, ActiveFunction activeFunction, boolean isDynamic, + boolean isAccurate, int rzType, double lParam) throws Exception {//隐层神经元 + super(id, upNub, "HiddenNerve", downNub, studyPoint, + init, activeFunction, isDynamic, isAccurate, rzType, lParam); this.depth = depth; } diff --git a/src/main/java/org/wlld/nerveEntity/Nerve.java b/src/main/java/org/wlld/nerveEntity/Nerve.java index 644ffd2..a28e49c 100644 --- a/src/main/java/org/wlld/nerveEntity/Nerve.java +++ b/src/main/java/org/wlld/nerveEntity/Nerve.java @@ -3,6 +3,7 @@ package org.wlld.nerveEntity; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.config.RZ; import org.wlld.i.ActiveFunction; import org.wlld.i.OutBack; import org.wlld.tools.ArithUtil; @@ -35,6 +36,8 @@ public abstract class Nerve { private int backNub = 0;//当前节点被反向传播的次数 protected ActiveFunction activeFunction; private boolean isAccurate = false;//是否保留精度 + private int rzType;//正则化类型,默认不进行正则化 + private double lParam;//正则参数 public Map getDendrites() { return dendrites; @@ -66,7 +69,7 @@ public abstract class Nerve { protected Nerve(int id, int upNub, String name, int downNub, double studyPoint, boolean init, ActiveFunction activeFunction - , boolean isDynamic, boolean isAccurate) throws Exception {//该神经元在同层神经元中的编号 + , boolean isDynamic, boolean isAccurate, int rzType, double lParam) throws Exception {//该神经元在同层神经元中的编号 this.id = id; this.upNub = upNub; this.name = name; @@ -74,8 +77,9 @@ public abstract class Nerve { this.studyPoint = studyPoint; this.activeFunction = activeFunction; this.isAccurate = isAccurate; + this.rzType = rzType; + this.lParam = lParam; initPower(init, isDynamic);//生成随机权重 - } protected void setStudyPoint(double studyPoint) { @@ -221,15 +225,32 @@ public abstract class Nerve { backSendMessage(eventId); } + private double regularization(double w, double param) {//正则化类型 + double re = 0.0; + if (rzType != RZ.NOT_RZ) { + if (rzType == RZ.L2) { + re = ArithUtil.mul(param, -w); + } else if (rzType == RZ.L1) { + if (w > 0) { + re = -param; + } else if (w < 0) { + re = param; + } + } + } + return re; + } + private void updateW(double h, long eventId) {//h是学习率 * 当前g(梯度) List list = features.get(eventId); - double stop = ArithUtil.sub(1, ArithUtil.div(ArithUtil.mul(studyPoint, 0.015), dendrites.size())); + double param = ArithUtil.div(ArithUtil.mul(studyPoint, lParam), dendrites.size()); for (Map.Entry entry : dendrites.entrySet()) { int key = entry.getKey();//上层隐层神经元的编号 double w = entry.getValue();//接收到编号为KEY的上层隐层神经元的权重 double bn = list.get(key - 1);//接收到编号为KEY的上层隐层神经元的输入 double wp = ArithUtil.mul(bn, h);//编号为KEY的上层隐层神经元权重的变化值 - w = ArithUtil.mul(w, stop); + double regular = regularization(w, param);//正则化抑制权重s + w = ArithUtil.add(w, regular); w = ArithUtil.add(w, wp);//修正后的编号为KEY的上层隐层神经元权重 double dm = ArithUtil.mul(w, gradient);//返回给相对应的神经元 // System.out.println("allG==" + allG + ",dm==" + dm); diff --git a/src/main/java/org/wlld/nerveEntity/OutNerve.java b/src/main/java/org/wlld/nerveEntity/OutNerve.java index e8aae9b..93dfb6c 100644 --- a/src/main/java/org/wlld/nerveEntity/OutNerve.java +++ b/src/main/java/org/wlld/nerveEntity/OutNerve.java @@ -23,8 +23,9 @@ public class OutNerve extends Nerve { public OutNerve(int id, int upNub, int downNub, double studyPoint, boolean init, ActiveFunction activeFunction, boolean isDynamic, boolean isAccurate - , boolean isShowLog) throws Exception { - super(id, upNub, "OutNerve", downNub, studyPoint, init, activeFunction, isDynamic, isAccurate); + , boolean isShowLog, int rzType, double lParam) throws Exception { + super(id, upNub, "OutNerve", downNub, studyPoint, init, + activeFunction, isDynamic, isAccurate, rzType, lParam); this.isShowLog = isShowLog; } diff --git a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java index bd5b555..e51bab9 100644 --- a/src/main/java/org/wlld/nerveEntity/SensoryNerve.java +++ b/src/main/java/org/wlld/nerveEntity/SensoryNerve.java @@ -15,7 +15,8 @@ import java.util.Map; public class SensoryNerve extends Nerve { public SensoryNerve(int id, int upNub) throws Exception { - super(id, upNub, "SensoryNerve", 0, 0.1, false, null, false, false); + super(id, upNub, "SensoryNerve", 0, 0.1, false, + null, false, false, 0, 0); } /** diff --git a/src/test/java/org/wlld/NerveDemo1.java b/src/test/java/org/wlld/NerveDemo1.java index 3672665..3b84d0c 100644 --- a/src/test/java/org/wlld/NerveDemo1.java +++ b/src/test/java/org/wlld/NerveDemo1.java @@ -1,6 +1,7 @@ package org.wlld; import org.wlld.MatrixTools.Matrix; +import org.wlld.config.RZ; import org.wlld.function.Sigmod; import org.wlld.i.OutBack; import org.wlld.nerveCenter.NerveManager; @@ -32,7 +33,8 @@ public class NerveDemo1 { * @param activeFunction 激活函数 * @param isDynamic 是否是动态神经元 */ - NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), false, true, 0); + NerveManager nerveManager = new NerveManager(2, 6, 1, 4, new Sigmod(), + false, true, 0, RZ.NOT_RZ, 0); nerveManager.init(true, false, false); @@ -108,7 +110,8 @@ public class NerveDemo1 { public static void test3() throws Exception { NerveManager nerveManager = new NerveManager(3, 6, 3 - , 3, new Sigmod(), false, true, 0); + , 3, new Sigmod(), + false, true, 0, RZ.NOT_RZ, 0); nerveManager.init(true, false, false);//初始化 List> data = new ArrayList<>();//正样本 List> dataB = new ArrayList<>();//负样本