From df70d9b20f56f4453c7c7d01312d638d6c67c5e9 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 7 Sep 2020 11:05:40 +0800 Subject: [PATCH 01/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../segmentation/Watershed.java | 14 +- .../org/wlld/regressionForest/Forest.java | 140 ++++++++++++++++++ .../regressionForest/RegressionForest.java | 56 +++++++ src/main/java/org/wlld/tools/Frequency.java | 19 +++ src/test/java/coverTest/FoodTest.java | 17 +-- 5 files changed, 227 insertions(+), 19 deletions(-) create mode 100644 src/main/java/org/wlld/regressionForest/Forest.java create mode 100644 src/main/java/org/wlld/regressionForest/RegressionForest.java diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java index 56958cb..6612b9f 100644 --- a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java +++ b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java @@ -277,13 +277,13 @@ public class Watershed { regionBodies.add(regionBody); } } -// for (RegionBody regionBody : regionBodies) { -// int minX = regionBody.getMinX(); -// int maxX = regionBody.getMaxX(); -// int minY = regionBody.getMinY(); -// int maxY = regionBody.getMaxY(); -// System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); -// } + for (RegionBody regionBody : regionBodies) { + int minX = regionBody.getMinX(); + int maxX = regionBody.getMaxX(); + int minY = regionBody.getMinY(); + int maxY = regionBody.getMaxY(); + System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); + } return iou(regionBodies); } diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java new file mode 100644 index 0000000..c5accd5 --- /dev/null +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -0,0 +1,140 @@ +package org.wlld.regressionForest; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.tools.Frequency; + +import java.util.Arrays; + + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 分段切割容器 + */ +public class Forest extends Frequency { + private Matrix conditionMatrix;//条件矩阵 + private Matrix resultMatrix;//结果矩阵 + private Forest forestLeft;//左森林 + private Forest forestRight;//右森林 + private int size; + private double min;//下限 + private double max;//上限 + private double resultVariance;//结果矩阵方差 + private double median;//结果矩阵中位数 + private double shrinkParameter;//方差收缩参数 + private double[] w; + + public Forest(int size, double shrinkParameter) { + this.size = size; + this.shrinkParameter = shrinkParameter; + } + + public double getResultVariance() { + return resultVariance; + } + + public void setResultVariance(double resultVariance) { + this.resultVariance = resultVariance; + } + + public void cut() throws Exception { + int y = resultMatrix.getY(); + if (y > 4) { + double[] dm = new double[y]; + for (int i = 0; i < y; i++) { + dm[i] = resultMatrix.getNumber(i, 0); + } + Arrays.sort(dm);//排序 + int z = y / 2; + median = dm[z]; + forestLeft = new Forest(size, shrinkParameter); + forestRight = new Forest(size, shrinkParameter); + Matrix conditionMatrixLeft = new Matrix(z, size);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(y - z, size);//条件矩阵右 + Matrix resultMatrixLeft = new Matrix(z, 1);//结果矩阵左 + Matrix resultMatrixRight = new Matrix(y - z, 1);//结果矩阵右 + forestLeft.setConditionMatrix(conditionMatrixLeft); + forestLeft.setResultMatrix(resultMatrixLeft); + forestRight.setConditionMatrix(conditionMatrixRight); + forestRight.setConditionMatrix(resultMatrixRight); + int leftIndex = 0;//左矩阵添加行数 + int rightIndex = 0;//右矩阵添加行数 + double[] resultLeft = new double[z]; + double[] resultRight = new double[y - z]; + for (int i = 0; i < y; i++) { + double nub = resultMatrix.getNumber(i, 0);//结果矩阵 + if (nub > median) {//进入右森林并计算右森林结果矩阵方差 + for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); + } + resultRight[rightIndex] = nub; + resultMatrixRight.setNub(rightIndex, 0, nub); + rightIndex++; + } else {//进入左森林并计算左森林结果矩阵方差 + for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j)); + } + resultLeft[leftIndex] = nub; + resultMatrixLeft.setNub(leftIndex, 0, nub); + leftIndex++; + } + } + //分区完成,计算两棵树结果矩阵的方差 + double leftVar = variance(resultLeft); + double rightVar = variance(resultRight); + double variance = resultVariance * shrinkParameter; + if (leftVar < variance || rightVar < variance) {//继续拆分 + double[] left = getLimit(resultLeft); + double[] right = getLimit(resultRight); + forestLeft.setMin(left[0]); + forestLeft.setMax(left[1]); + forestRight.setMin(right[0]); + forestRight.setMax(right[1]); + } else {//不继续拆分 + forestLeft = null; + forestRight = null; + } + } + } + + public double getMin() { + return min; + } + + public void setMin(double min) { + this.min = min; + } + + public double getMax() { + return max; + } + + public void setMax(double max) { + this.max = max; + } + + public Matrix getConditionMatrix() { + return conditionMatrix; + } + + public void setConditionMatrix(Matrix conditionMatrix) { + this.conditionMatrix = conditionMatrix; + } + + public Matrix getResultMatrix() { + return resultMatrix; + } + + public void setResultMatrix(Matrix resultMatrix) { + this.resultMatrix = resultMatrix; + } + + public double[] getW() { + return w; + } + + public void setW(double[] w) { + this.w = w; + } +} diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java new file mode 100644 index 0000000..d3690e0 --- /dev/null +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -0,0 +1,56 @@ +package org.wlld.regressionForest; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 回归森林 + */ +public class RegressionForest { + private double[] w; + private Matrix conditionMatrix;//条件矩阵 + private Matrix resultMatrix;//结果矩阵 + private int featureNub;//特征数量 + private int xIndex = 0;//记录插入位置 + + public RegressionForest(int size, int featureNub) throws Exception {//初始化 + if (size > 0 && featureNub > 0) { + this.featureNub = featureNub; + w = new double[size]; + conditionMatrix = new Matrix(size, featureNub); + resultMatrix = new Matrix(size, 1); + } else { + throw new Exception("size and featureNub too small"); + } + } + + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 + if (feature.length == featureNub) { + for (int i = 0; i < featureNub; i++) { + if (i < featureNub - 1) { + conditionMatrix.setNub(xIndex, i, feature[i]); + } else { + conditionMatrix.setNub(xIndex, i, 1.0); + resultMatrix.setNub(xIndex, 0, result); + } + } + xIndex++; + } else { + throw new Exception("feature length is not equals"); + } + } + + public void regression() throws Exception {//开始进行回归 + if (xIndex > 0) { + Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); + for (int i = 0; i < ws.getX(); i++) { + w[i] = ws.getNumber(i, 0); + } + } else { + throw new Exception("regression matrix size is zero"); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/wlld/tools/Frequency.java b/src/main/java/org/wlld/tools/Frequency.java index d7bb2cd..8bddcd5 100644 --- a/src/main/java/org/wlld/tools/Frequency.java +++ b/src/main/java/org/wlld/tools/Frequency.java @@ -99,4 +99,23 @@ public abstract class Frequency {//统计频数 } return ArithUtil.div(my, all); } + + public double[] getLimit(double[] m) {//获取数组中的最大值和最小值,最小值在前,最大值在后 + double[] limit = new double[2]; + double max = 0; + double min = -1; + int l = m.length; + for (int i = 0; i < l; i++) { + double nub = m[i]; + if (min == -1 || nub < min) { + min = nub; + } + if (nub > max) { + max = nub; + } + } + limit[0] = min; + limit[1] = max; + return limit; + } } diff --git a/src/test/java/coverTest/FoodTest.java b/src/test/java/coverTest/FoodTest.java index cf2ea69..b0ea1d6 100644 --- a/src/test/java/coverTest/FoodTest.java +++ b/src/test/java/coverTest/FoodTest.java @@ -61,7 +61,7 @@ public class FoodTest { Food food = templeConfig.getFood(); // cutting.setMaxRain(360);//切割阈值 - cutting.setTh(0.3); + cutting.setTh(0.6); cutting.setRegionNub(200); cutting.setMaxIou(2.0); //knn参数 @@ -73,8 +73,8 @@ public class FoodTest { //菜品识别实体类 food.setShrink(20);//缩紧像素 food.setTimes(2);//聚类数据增强 - food.setRowMark(0.1);//0.12 - food.setColumnMark(0.1);//0.25 + food.setRowMark(0.12);//0.12 + food.setColumnMark(0.12);//0.25 food.setRegressionNub(20000); food.setTrayTh(0.08); templeConfig.setClassifier(Classifier.KNN); @@ -99,18 +99,11 @@ public class FoodTest { ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg"); operation.setTray(threeChannelMatrix); for (int i = 1; i <= 1; i++) { - ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/a1.jpg"); - ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/b.jpg"); - ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/c.jpg"); + ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/test.jpg"); operation.colorStudy(threeChannelMatrix1, 1, specificationsList); - operation.colorStudy(threeChannelMatrix2, 2, specificationsList); - operation.colorStudy(threeChannelMatrix3, 3, specificationsList); } - -// minX==301,minY==430,maxX==854,maxY==920 -// minX==497,minY==1090,maxX==994,maxY==1520 - test2(templeConfig); + // test2(templeConfig); } public static void study() throws Exception { From d17b58c3c11556743afde1aa0d8604d753df5630 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 7 Sep 2020 17:22:21 +0800 Subject: [PATCH 02/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 27 +++++--- .../regressionForest/RegressionForest.java | 69 +++++++++++++++++-- 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c5accd5..431394d 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -17,7 +17,7 @@ public class Forest extends Frequency { private Matrix resultMatrix;//结果矩阵 private Forest forestLeft;//左森林 private Forest forestRight;//右森林 - private int size; + private int featureSize; private double min;//下限 private double max;//上限 private double resultVariance;//结果矩阵方差 @@ -25,9 +25,10 @@ public class Forest extends Frequency { private double shrinkParameter;//方差收缩参数 private double[] w; - public Forest(int size, double shrinkParameter) { - this.size = size; + public Forest(int featureSize, double shrinkParameter) { + this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; + w = new double[featureSize]; } public double getResultVariance() { @@ -48,10 +49,10 @@ public class Forest extends Frequency { Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; - forestLeft = new Forest(size, shrinkParameter); - forestRight = new Forest(size, shrinkParameter); - Matrix conditionMatrixLeft = new Matrix(z, size);//条件矩阵左 - Matrix conditionMatrixRight = new Matrix(y - z, size);//条件矩阵右 + forestLeft = new Forest(featureSize, shrinkParameter); + forestRight = new Forest(featureSize, shrinkParameter); + Matrix conditionMatrixLeft = new Matrix(z, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(y - z, featureSize);//条件矩阵右 Matrix resultMatrixLeft = new Matrix(z, 1);//结果矩阵左 Matrix resultMatrixRight = new Matrix(y - z, 1);//结果矩阵右 forestLeft.setConditionMatrix(conditionMatrixLeft); @@ -65,14 +66,14 @@ public class Forest extends Frequency { for (int i = 0; i < y; i++) { double nub = resultMatrix.getNumber(i, 0);//结果矩阵 if (nub > median) {//进入右森林并计算右森林结果矩阵方差 - for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); } resultRight[rightIndex] = nub; resultMatrixRight.setNub(rightIndex, 0, nub); rightIndex++; } else {//进入左森林并计算左森林结果矩阵方差 - for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j)); } resultLeft[leftIndex] = nub; @@ -137,4 +138,12 @@ public class Forest extends Frequency { public void setW(double[] w) { this.w = w; } + + public Forest getForestLeft() { + return forestLeft; + } + + public Forest getForestRight() { + return forestRight; + } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index d3690e0..99a254d 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -2,6 +2,7 @@ package org.wlld.regressionForest; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.tools.Frequency; /** * @param @@ -9,30 +10,44 @@ import org.wlld.MatrixTools.MatrixOperation; * @Author LiDaPeng * @Description 回归森林 */ -public class RegressionForest { +public class RegressionForest extends Frequency { private double[] w; private Matrix conditionMatrix;//条件矩阵 private Matrix resultMatrix;//结果矩阵 + private Forest forest; private int featureNub;//特征数量 private int xIndex = 0;//记录插入位置 + private double[] results;//结果数组 + private double min;//结果最小值 + private double max;//结果最大值 public RegressionForest(int size, int featureNub) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; w = new double[size]; + results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); + forest = new Forest(featureNub, 0.9); + forest.setW(w); + forest.setConditionMatrix(conditionMatrix); + forest.setResultMatrix(resultMatrix); } else { throw new Exception("size and featureNub too small"); } } + public void getDist(double[] feature, double result) {//获取特征误差结果 + + } + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 if (feature.length == featureNub) { for (int i = 0; i < featureNub; i++) { if (i < featureNub - 1) { conditionMatrix.setNub(xIndex, i, feature[i]); } else { + results[xIndex] = result; conditionMatrix.setNub(xIndex, i, 1.0); resultMatrix.setNub(xIndex, 0, result); } @@ -43,14 +58,54 @@ public class RegressionForest { } } + public void start() throws Exception {//开始进行分段 + if (forest != null) { + double[] limit = getLimit(results); + min = limit[0]; + max = limit[1]; + start(forest); + } else { + throw new Exception("rootForest is null"); + } + } + + private void start(Forest forest) throws Exception { + forest.cut(); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + start(forestLeft); + start(forestRight); + } + + } + public void regression() throws Exception {//开始进行回归 - if (xIndex > 0) { - Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); - for (int i = 0; i < ws.getX(); i++) { - w[i] = ws.getNumber(i, 0); - } + if (forest != null) { + regressionTree(forest); } else { - throw new Exception("regression matrix size is zero"); + throw new Exception("rootForest is null"); + } + } + + private void regressionTree(Forest forest) throws Exception { + regression(forest); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + regressionTree(forestLeft); + regressionTree(forestRight); + } + + } + + private void regression(Forest forest) throws Exception { + Matrix conditionMatrix = forest.getConditionMatrix(); + Matrix resultMatrix = forest.getResultMatrix(); + double[] w = forest.getW(); + Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); + for (int i = 0; i < ws.getX(); i++) { + w[i] = ws.getNumber(i, 0); } } } \ No newline at end of file From cdeea5bcec5d87f5421063ebc6c034e57b0810a2 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Tue, 8 Sep 2020 15:19:43 +0800 Subject: [PATCH 03/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 32 +++-------- .../regressionForest/RegressionForest.java | 54 +++++++++++++++++-- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 431394d..7ba3e49 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -18,8 +18,6 @@ public class Forest extends Frequency { private Forest forestLeft;//左森林 private Forest forestRight;//右森林 private int featureSize; - private double min;//下限 - private double max;//上限 private double resultVariance;//结果矩阵方差 private double median;//结果矩阵中位数 private double shrinkParameter;//方差收缩参数 @@ -31,6 +29,10 @@ public class Forest extends Frequency { w = new double[featureSize]; } + public double getMedian() { + return median; + } + public double getResultVariance() { return resultVariance; } @@ -85,36 +87,14 @@ public class Forest extends Frequency { double leftVar = variance(resultLeft); double rightVar = variance(resultRight); double variance = resultVariance * shrinkParameter; - if (leftVar < variance || rightVar < variance) {//继续拆分 - double[] left = getLimit(resultLeft); - double[] right = getLimit(resultRight); - forestLeft.setMin(left[0]); - forestLeft.setMax(left[1]); - forestRight.setMin(right[0]); - forestRight.setMax(right[1]); - } else {//不继续拆分 + if (leftVar > variance && rightVar > variance) {//不进行拆分,回退 forestLeft = null; forestRight = null; + median = 0; } } } - public double getMin() { - return min; - } - - public void setMin(double min) { - this.min = min; - } - - public double getMax() { - return max; - } - - public void setMax(double max) { - this.max = max; - } - public Matrix getConditionMatrix() { return conditionMatrix; } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 99a254d..14f7fa7 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -24,7 +24,7 @@ public class RegressionForest extends Frequency { public RegressionForest(int size, int featureNub) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; - w = new double[size]; + w = new double[featureNub]; results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); @@ -37,8 +37,56 @@ public class RegressionForest extends Frequency { } } - public void getDist(double[] feature, double result) {//获取特征误差结果 + public double getDist(double[] feature, double result) {//获取特征误差结果 + Forest forestFinish; + if (result <= min) {//直接找下边界区域 + forestFinish = getLimitRegion(forest, false); + } else if (result >= max) {//直接找到上边界区域 + forestFinish = getLimitRegion(forest, true); + } else { + forestFinish = getRegion(forest, result); + } + //计算误差 + double[] w = forestFinish.getW(); + double sigma = 0; + for (int i = 0; i < w.length; i++) { + double nub; + if (i < w.length - 1) { + nub = w[i] * feature[i]; + } else { + nub = w[i]; + } + sigma = sigma + nub; + } + return Math.abs(result - sigma); + } + private Forest getRegion(Forest forest, double result) { + double median = forest.getMedian(); + if (median > 0) {//进行了拆分 + if (result > median) {//向右走 + forest = forest.getForestRight(); + } else {//向左走 + forest = forest.getForestLeft(); + } + return getRegion(forest, result); + } else {//没有拆分 + return forest; + } + } + + private Forest getLimitRegion(Forest forest, boolean isMax) { + Forest forestSon; + if (isMax) { + forestSon = forest.getForestRight(); + } else { + forestSon = forest.getForestLeft(); + } + if (forestSon != null) { + return getLimitRegion(forestSon, isMax); + } else { + return forest; + } } public void insertFeature(double[] feature, double result) throws Exception {//插入数据 @@ -99,7 +147,7 @@ public class RegressionForest extends Frequency { } - private void regression(Forest forest) throws Exception { + private void regression(Forest forest) throws Exception {//对分段进行线性回归 Matrix conditionMatrix = forest.getConditionMatrix(); Matrix resultMatrix = forest.getResultMatrix(); double[] w = forest.getW(); From 01564d389f7b0e09bc132fa56cba679f1c4ab0f5 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Tue, 8 Sep 2020 15:45:03 +0800 Subject: [PATCH 04/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/imageRecognition/border/Knn.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/main/java/org/wlld/imageRecognition/border/Knn.java b/src/main/java/org/wlld/imageRecognition/border/Knn.java index 68fa9a4..5afaaca 100644 --- a/src/main/java/org/wlld/imageRecognition/border/Knn.java +++ b/src/main/java/org/wlld/imageRecognition/border/Knn.java @@ -22,6 +22,13 @@ public class Knn {//KNN分类器 featureMap.remove(type); } + public void revoke(int type, int nub) {//撤销一个类别最新的 + List list = featureMap.get(type); + for (int i = 0; i < nub; i++) { + list.remove(list.size() - 1); + } + } + public void insertMatrix(Matrix vector, int tag) throws Exception { if (vector.isVector() && vector.isRowVector()) { if (featureMap.size() == 0) { From 35a5d91d379899ab43b2ffcd161a5eee8be75efc Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Wed, 9 Sep 2020 09:27:38 +0800 Subject: [PATCH 05/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/imageRecognition/border/Knn.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/main/java/org/wlld/imageRecognition/border/Knn.java b/src/main/java/org/wlld/imageRecognition/border/Knn.java index 5afaaca..81a0f48 100644 --- a/src/main/java/org/wlld/imageRecognition/border/Knn.java +++ b/src/main/java/org/wlld/imageRecognition/border/Knn.java @@ -29,6 +29,15 @@ public class Knn {//KNN分类器 } } + public int getNub(int type) {//获取该分类模型的数量 + int nub = 0; + List list = featureMap.get(type); + if (list != null) { + nub = list.size(); + } + return nub; + } + public void insertMatrix(Matrix vector, int tag) throws Exception { if (vector.isVector() && vector.isRowVector()) { if (featureMap.size() == 0) { From 69aad8e21373915e50ee532257093ee60c3da845 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Sat, 12 Sep 2020 17:55:45 +0800 Subject: [PATCH 06/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 92 +++++++++++++++++-- .../regressionForest/RegressionForest.java | 16 +++- src/test/java/coverTest/ForestTest.java | 69 ++++++++++++++ 3 files changed, 164 insertions(+), 13 deletions(-) create mode 100644 src/test/java/coverTest/ForestTest.java diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 7ba3e49..c738ff5 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -4,6 +4,9 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.tools.Frequency; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; /** @@ -21,7 +24,9 @@ public class Forest extends Frequency { private double resultVariance;//结果矩阵方差 private double median;//结果矩阵中位数 private double shrinkParameter;//方差收缩参数 + private Matrix pc;//需要映射的基 private double[] w; + private int cosSize = 10;//cos 分成几份 public Forest(int featureSize, double shrinkParameter) { this.featureSize = featureSize; @@ -41,8 +46,72 @@ public class Forest extends Frequency { this.resultVariance = resultVariance; } + //检测中位数median有多少个一样的值 + private int getEqualNub(double median, double[] dm) { + int equalNub = 0; + for (int i = 0; i < dm.length; i++) { + if (median == dm[i]) { + equalNub++; + } + } + return equalNub; + } + + private void createG() throws Exception {//生成新基 + double[] cg = new double[featureSize - 1]; + Random random = new Random(); + double sigma = 0; + for (int i = 0; i < featureSize - 1; i++) { + double rm = random.nextDouble(); + cg[i] = rm; + sigma = sigma + Math.pow(rm, 2); + } + double cosOne = 1.0D / cosSize; + double[] ag = new double[cosSize - 1]; + for (int i = 1; i < cosSize; i++) { + double cos = cosOne * i; + ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); + } + int x = (cosSize - 1) * featureSize; + pc = new Matrix(x, featureSize); + for (int i = 0; i < featureSize; i++) { + Matrix matrix = new Matrix(ag.length, featureSize); + for (int j = 0; j < ag.length; j++) { + for (int k = 0; k < featureSize; k++) { + if (k != i) { + if (k < i) { + matrix.setNub(j, k, cg[k]); + } else { + matrix.setNub(j, k, cg[k - 1]); + } + } else { + matrix.setNub(j, k, ag[j]); + } + } + } + } + } + + private void findG() throws Exception {//寻找新的切入维度 + // 先尝试从原有维度切入 + Map varMap = new HashMap<>();//保存原有维度方差 + for (int i = 0; i < featureSize; i++) { + double[] g = new double[conditionMatrix.getX()]; + for (int j = 0; j < g.length; j++) { + if (i < featureSize - 1) { + g[j] = conditionMatrix.getNumber(j, i); + } else { + g[j] = resultMatrix.getNumber(j, 0); + } + } + double var = variance(g);//计算方差 + varMap.put(i, var); + } + + } + public void cut() throws Exception { - int y = resultMatrix.getY(); + int y = resultMatrix.getX(); if (y > 4) { double[] dm = new double[y]; for (int i = 0; i < y; i++) { @@ -51,20 +120,23 @@ public class Forest extends Frequency { Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; + //检测中位数median有多少个一样的值 + int equalNub = getEqualNub(median, dm); + //System.out.println("equalNub==" + equalNub + ",y==" + y); forestLeft = new Forest(featureSize, shrinkParameter); forestRight = new Forest(featureSize, shrinkParameter); - Matrix conditionMatrixLeft = new Matrix(z, featureSize);//条件矩阵左 - Matrix conditionMatrixRight = new Matrix(y - z, featureSize);//条件矩阵右 - Matrix resultMatrixLeft = new Matrix(z, 1);//结果矩阵左 - Matrix resultMatrixRight = new Matrix(y - z, 1);//结果矩阵右 + Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 + Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 + Matrix resultMatrixRight = new Matrix(y - z - equalNub, 1);//结果矩阵右 forestLeft.setConditionMatrix(conditionMatrixLeft); forestLeft.setResultMatrix(resultMatrixLeft); forestRight.setConditionMatrix(conditionMatrixRight); - forestRight.setConditionMatrix(resultMatrixRight); + forestRight.setResultMatrix(resultMatrixRight); int leftIndex = 0;//左矩阵添加行数 int rightIndex = 0;//右矩阵添加行数 - double[] resultLeft = new double[z]; - double[] resultRight = new double[y - z]; + double[] resultLeft = new double[z + equalNub]; + double[] resultRight = new double[y - z - equalNub]; for (int i = 0; i < y; i++) { double nub = resultMatrix.getNumber(i, 0);//结果矩阵 if (nub > median) {//进入右森林并计算右森林结果矩阵方差 @@ -87,10 +159,14 @@ public class Forest extends Frequency { double leftVar = variance(resultLeft); double rightVar = variance(resultRight); double variance = resultVariance * shrinkParameter; + System.out.println("var==" + variance + ",leftVar==" + leftVar + ",rightVar==" + rightVar); if (leftVar > variance && rightVar > variance) {//不进行拆分,回退 forestLeft = null; forestRight = null; median = 0; + } else { + forestLeft.setResultVariance(leftVar); + forestRight.setResultVariance(rightVar); } } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 14f7fa7..1b73888 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -4,6 +4,8 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; +import java.util.Arrays; + /** * @param * @DATA @@ -21,14 +23,14 @@ public class RegressionForest extends Frequency { private double min;//结果最小值 private double max;//结果最大值 - public RegressionForest(int size, int featureNub) throws Exception {//初始化 + public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; w = new double[featureNub]; results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); - forest = new Forest(featureNub, 0.9); + forest = new Forest(featureNub, shrinkParameter); forest.setW(w); forest.setConditionMatrix(conditionMatrix); forest.setResultMatrix(resultMatrix); @@ -90,7 +92,7 @@ public class RegressionForest extends Frequency { } public void insertFeature(double[] feature, double result) throws Exception {//插入数据 - if (feature.length == featureNub) { + if (feature.length == featureNub - 1) { for (int i = 0; i < featureNub; i++) { if (i < featureNub - 1) { conditionMatrix.setNub(xIndex, i, feature[i]); @@ -106,8 +108,10 @@ public class RegressionForest extends Frequency { } } - public void start() throws Exception {//开始进行分段 + public void startStudy() throws Exception {//开始进行分段 if (forest != null) { + //计算方差 + forest.setResultVariance(variance(results)); double[] limit = getLimit(results); min = limit[0]; max = limit[1]; @@ -150,10 +154,12 @@ public class RegressionForest extends Frequency { private void regression(Forest forest) throws Exception {//对分段进行线性回归 Matrix conditionMatrix = forest.getConditionMatrix(); Matrix resultMatrix = forest.getResultMatrix(); - double[] w = forest.getW(); Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); + double[] w = forest.getW(); for (int i = 0; i < ws.getX(); i++) { w[i] = ws.getNumber(i, 0); } + System.out.println(Arrays.toString(w)); + System.out.println("=========================="); } } \ No newline at end of file diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java new file mode 100644 index 0000000..a01b63d --- /dev/null +++ b/src/test/java/coverTest/ForestTest.java @@ -0,0 +1,69 @@ +package coverTest; + +import org.wlld.regressionForest.RegressionForest; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class ForestTest { + public static void main(String[] args) throws Exception { + test(); + } + + public static void test() throws Exception {//对分段回归进行测试 + int size = 2000; + RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); + List a = fun(0.1, 0.2, 0.3, size); + List b = fun(0.3, 0.2, 0.1, size); + for (int i = 0; i < 1000; i++) { + double[] featureA = a.get(i); + double[] featureB = b.get(i); + double[] testA = new double[]{featureA[0], featureA[1]}; + double[] testB = new double[]{featureB[0], featureB[1]}; + regressionForest.insertFeature(testA, featureA[2]); + regressionForest.insertFeature(testB, featureB[2]); + } + regressionForest.startStudy(); + regressionForest.regression();//这里进行回归 + + double sigma = 0; + for (int i = 0; i < 1000; i++) { + double[] feature = a.get(i); + double[] test = new double[]{feature[0], feature[1]}; + double dist = regressionForest.getDist(test, feature[2]); + sigma = sigma + Math.pow(dist, 2); + } + double avs = sigma / size; + System.out.println("a误差:" + avs); + sigma = 0; + for (int i = 0; i < 1000; i++) { + double[] feature = b.get(i); + double[] test = new double[]{feature[0], feature[1]}; + double dist = regressionForest.getDist(test, feature[2]); + sigma = sigma + Math.pow(dist, 2); + } + double avs2 = sigma / size; + System.out.println("b误差:" + avs2); + + } + + public static List fun(double w1, double w2, double w3, int size) {//生成假数据 + List list = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < size; i++) { + double a = random.nextDouble(); + double b = random.nextDouble(); + double c = w1 * a + w2 * b + w3; + double[] data = new double[]{a, b, c}; + list.add(data); + } + return list; + } +} From de0722e5877afc95c2a9d0cbcde074a754eb8b12 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 14 Sep 2020 17:26:47 +0800 Subject: [PATCH 07/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/MatrixTools/MatrixOperation.java | 2 +- .../org/wlld/regressionForest/Forest.java | 101 ++++++++++-------- .../regressionForest/RegressionForest.java | 72 ++++++++++++- 3 files changed, 128 insertions(+), 47 deletions(-) diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index ec5f647..32e4c82 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -243,7 +243,7 @@ public class MatrixOperation { double nub = 0; for (int i = 0; i < matrix.getX(); i++) { for (int j = 0; j < matrix.getY(); j++) { - nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub); + nub = Math.pow(matrix.getNumber(i, j), 2) + nub; } } return Math.sqrt(nub); diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c738ff5..fbe3c0f 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -1,12 +1,10 @@ package org.wlld.regressionForest; import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; +import java.util.*; /** @@ -24,13 +22,16 @@ public class Forest extends Frequency { private double resultVariance;//结果矩阵方差 private double median;//结果矩阵中位数 private double shrinkParameter;//方差收缩参数 - private Matrix pc;//需要映射的基 + private Matrix pc;//需要映射的基的集合 + private Matrix pc1;//需要映射的基 private double[] w; - private int cosSize = 10;//cos 分成几份 + private boolean isOldG = true;//是否使用老基 + private int oldGId = 0;//老基的id - public Forest(int featureSize, double shrinkParameter) { + public Forest(int featureSize, double shrinkParameter, Matrix pc) { this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; + this.pc = pc; w = new double[featureSize]; } @@ -57,44 +58,22 @@ public class Forest extends Frequency { return equalNub; } - private void createG() throws Exception {//生成新基 - double[] cg = new double[featureSize - 1]; - Random random = new Random(); - double sigma = 0; - for (int i = 0; i < featureSize - 1; i++) { - double rm = random.nextDouble(); - cg[i] = rm; - sigma = sigma + Math.pow(rm, 2); - } - double cosOne = 1.0D / cosSize; - double[] ag = new double[cosSize - 1]; - for (int i = 1; i < cosSize; i++) { - double cos = cosOne * i; - ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); - } - int x = (cosSize - 1) * featureSize; - pc = new Matrix(x, featureSize); - for (int i = 0; i < featureSize; i++) { - Matrix matrix = new Matrix(ag.length, featureSize); - for (int j = 0; j < ag.length; j++) { - for (int k = 0; k < featureSize; k++) { - if (k != i) { - if (k < i) { - matrix.setNub(j, k, cg[k]); - } else { - matrix.setNub(j, k, cg[k - 1]); - } - } else { - matrix.setNub(j, k, ag[j]); - } + private void findG() throws Exception {//寻找新的切入维度 + // 先尝试从原有维度切入 + int xSize = conditionMatrix.getX(); + int ySize = conditionMatrix.getY(); + Matrix matrix = new Matrix(xSize, ySize); + for (int i = 0; i < xSize; i++) { + for (int j = 0; j < ySize; j++) { + if (j < ySize - 1) { + matrix.setNub(i, j, conditionMatrix.getNumber(i, j)); + } else { + matrix.setNub(i, j, resultMatrix.getNumber(i, 0)); } } } - } - - private void findG() throws Exception {//寻找新的切入维度 - // 先尝试从原有维度切入 - Map varMap = new HashMap<>();//保存原有维度方差 + double maxOld = 0; + int type = 0; for (int i = 0; i < featureSize; i++) { double[] g = new double[conditionMatrix.getX()]; for (int j = 0; j < g.length; j++) { @@ -105,9 +84,41 @@ public class Forest extends Frequency { } } double var = variance(g);//计算方差 - varMap.put(i, var); + if (var > maxOld) { + maxOld = var; + type = i; + } + } + int x = pc.getX(); + double max = 0; + for (int i = 0; i < x; i++) { + Matrix g = pc.getRow(i); + double gNorm = MatrixOperation.getNorm(g); + double[] var = new double[xSize]; + for (int j = 0; j < xSize; j++) { + Matrix parameter = matrix.getRow(j); + double dist = transG(g, parameter, gNorm); + var[j] = dist; + } + double variance = variance(var); + if (variance > max) { + max = variance; + pc1 = g; + } } + //找到非原始基最离散的新基 + if (max > maxOld) {//使用新基 + isOldG = false; + } else {//使用原有基 + isOldG = true; + oldGId = type; + } + } + private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 + //先求内积 + double innerProduct = MatrixOperation.innerProduct(g, parameter); + return innerProduct / gNorm; } public void cut() throws Exception { @@ -123,8 +134,8 @@ public class Forest extends Frequency { //检测中位数median有多少个一样的值 int equalNub = getEqualNub(median, dm); //System.out.println("equalNub==" + equalNub + ",y==" + y); - forestLeft = new Forest(featureSize, shrinkParameter); - forestRight = new Forest(featureSize, shrinkParameter); + forestLeft = new Forest(featureSize, shrinkParameter, pc); + forestRight = new Forest(featureSize, shrinkParameter, pc); Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 1b73888..ae99e42 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -5,6 +5,9 @@ import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; /** * @param @@ -22,6 +25,16 @@ public class RegressionForest extends Frequency { private double[] results;//结果数组 private double min;//结果最小值 private double max;//结果最大值 + private Matrix pc;//需要映射的基 + private int cosSize = 10;//cos 分成几份 + + public int getCosSize() { + return cosSize; + } + + public void setCosSize(int cosSize) { + this.cosSize = cosSize; + } public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 if (size > 0 && featureNub > 0) { @@ -30,7 +43,8 @@ public class RegressionForest extends Frequency { results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); - forest = new Forest(featureNub, shrinkParameter); + createG(); + forest = new Forest(featureNub, shrinkParameter, pc); forest.setW(w); forest.setConditionMatrix(conditionMatrix); forest.setResultMatrix(resultMatrix); @@ -91,6 +105,62 @@ public class RegressionForest extends Frequency { } } + private void createG() throws Exception {//生成新基 + double[] cg = new double[featureNub - 1]; + Random random = new Random(); + double sigma = 0; + for (int i = 0; i < featureNub - 1; i++) { + double rm = random.nextDouble(); + cg[i] = rm; + sigma = sigma + Math.pow(rm, 2); + } + double cosOne = 1.0D / cosSize; + double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 + for (int i = 1; i < cosSize; i++) { + double cos = cosOne * i; + ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); + } + int x = (cosSize - 1) * featureNub; + pc = new Matrix(x, featureNub); + for (int i = 0; i < featureNub; i++) {//遍历所有的固定基 + //以某个固定基摆动的所有新基集合的矩阵 + Matrix matrix = new Matrix(ag.length, featureNub); + for (int j = 0; j < ag.length; j++) { + for (int k = 0; k < featureNub; k++) { + if (k != i) { + if (k < i) { + matrix.setNub(j, k, cg[k]); + } else { + matrix.setNub(j, k, cg[k - 1]); + } + } else { + matrix.setNub(j, k, ag[j]); + } + } + } + //将一个固定基内摆动的新基都装到最大的集合内 + int index = (cosSize - 1) * i; + push(pc, matrix, index); + } + } + + //将两个矩阵从上到下进行合并 + private void push(Matrix mother, Matrix son, int index) throws Exception { + if (mother.getY() == son.getY()) { + int x = index + son.getX(); + int y = mother.getY(); + int start = 0; + for (int i = index; i < x; i++) { + for (int j = 0; j < y; j++) { + mother.setNub(i, j, son.getNumber(start, j)); + } + start++; + } + } else { + throw new Exception("matrix Y is not equals"); + } + } + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 if (feature.length == featureNub - 1) { for (int i = 0; i < featureNub; i++) { From 79ccfdf340367e8d076d8e4621eb394ddcb7e892 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Wed, 16 Sep 2020 17:08:04 +0800 Subject: [PATCH 08/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 66 +++++++++++++++---- .../regressionForest/RegressionForest.java | 6 +- src/test/java/coverTest/ForestTest.java | 11 ++-- 3 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index fbe3c0f..9127010 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -27,6 +27,8 @@ public class Forest extends Frequency { private double[] w; private boolean isOldG = true;//是否使用老基 private int oldGId = 0;//老基的id + private Matrix matrixAll;//全矩阵 + private double gNorm;//新维度的摸 public Forest(int featureSize, double shrinkParameter, Matrix pc) { this.featureSize = featureSize; @@ -58,17 +60,17 @@ public class Forest extends Frequency { return equalNub; } - private void findG() throws Exception {//寻找新的切入维度 + private double[] findG() throws Exception {//寻找新的切入维度 // 先尝试从原有维度切入 int xSize = conditionMatrix.getX(); int ySize = conditionMatrix.getY(); - Matrix matrix = new Matrix(xSize, ySize); + matrixAll = new Matrix(xSize, ySize); for (int i = 0; i < xSize; i++) { for (int j = 0; j < ySize; j++) { if (j < ySize - 1) { - matrix.setNub(i, j, conditionMatrix.getNumber(i, j)); + matrixAll.setNub(i, j, conditionMatrix.getNumber(i, j)); } else { - matrix.setNub(i, j, resultMatrix.getNumber(i, 0)); + matrixAll.setNub(i, j, resultMatrix.getNumber(i, 0)); } } } @@ -83,7 +85,7 @@ public class Forest extends Frequency { g[j] = resultMatrix.getNumber(j, 0); } } - double var = variance(g);//计算方差 + double var = dc(g);//计算方差 if (var > maxOld) { maxOld = var; type = i; @@ -96,11 +98,11 @@ public class Forest extends Frequency { double gNorm = MatrixOperation.getNorm(g); double[] var = new double[xSize]; for (int j = 0; j < xSize; j++) { - Matrix parameter = matrix.getRow(j); + Matrix parameter = matrixAll.getRow(j); double dist = transG(g, parameter, gNorm); var[j] = dist; } - double variance = variance(var); + double variance = dc(var); if (variance > max) { max = variance; pc1 = g; @@ -113,6 +115,7 @@ public class Forest extends Frequency { isOldG = true; oldGId = type; } + return findTwo(xSize); } private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 @@ -121,19 +124,42 @@ public class Forest extends Frequency { return innerProduct / gNorm; } + private double[] findTwo(int dataSize) throws Exception { + Matrix matrix;//创建一个列向量 + double[] data = new double[dataSize]; + if (isOldG) {//使用原有基 + if (oldGId == featureSize - 1) {//从结果矩阵提取数据 + matrix = resultMatrix; + } else {//从条件矩阵中提取数据 + matrix = conditionMatrix.getColumn(oldGId); + } + //将数据塞入数组 + for (int i = 0; i < dataSize; i++) { + data[i] = matrix.getNumber(i, 0); + } + } else {//使用转换基 + int x = matrixAll.getX(); + gNorm = MatrixOperation.getNorm(pc1); + for (int i = 0; i < x; i++) { + Matrix parameter = matrixAll.getRow(i); + double dist = transG(pc1, parameter, gNorm); + data[i] = dist; + } + } + Arrays.sort(data);//对数据进行排序 + return data; + } + public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 4) { - double[] dm = new double[y]; - for (int i = 0; i < y; i++) { - dm[i] = resultMatrix.getNumber(i, 0); - } + if (y > 8) { + double[] dm = findG(); Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; //检测中位数median有多少个一样的值 int equalNub = getEqualNub(median, dm); - //System.out.println("equalNub==" + equalNub + ",y==" + y); + //////////// forestLeft = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc); Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 @@ -148,8 +174,20 @@ public class Forest extends Frequency { int rightIndex = 0;//右矩阵添加行数 double[] resultLeft = new double[z + equalNub]; double[] resultRight = new double[y - z - equalNub]; + ////// for (int i = 0; i < y; i++) { - double nub = resultMatrix.getNumber(i, 0);//结果矩阵 + double nub; + if (isOldG) {//使用原有基 + if (oldGId == featureSize - 1) {//从结果矩阵提取数据 + nub = resultMatrix.getNumber(i, 0); + } else {//从条件矩阵中提取数据 + nub = conditionMatrix.getNumber(i, oldGId); + } + } else {//使用新基 + Matrix parameter = matrixAll.getRow(i); + nub = transG(pc1, parameter, gNorm); + } + //double nub = resultMatrix.getNumber(i, 0);//结果矩阵 if (nub > median) {//进入右森林并计算右森林结果矩阵方差 for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index ae99e42..00d9a6c 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -26,7 +26,7 @@ public class RegressionForest extends Frequency { private double min;//结果最小值 private double max;//结果最大值 private Matrix pc;//需要映射的基 - private int cosSize = 10;//cos 分成几份 + private int cosSize = 20;//cos 分成几份 public int getCosSize() { return cosSize; @@ -116,8 +116,8 @@ public class RegressionForest extends Frequency { } double cosOne = 1.0D / cosSize; double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 - for (int i = 1; i < cosSize; i++) { - double cos = cosOne * i; + for (int i = 0; i < cosSize - 1; i++) { + double cos = cosOne * (i + 1); ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); } int x = (cosSize - 1) * featureNub; diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index a01b63d..66cd738 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -20,8 +20,9 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); - List a = fun(0.1, 0.2, 0.3, size); - List b = fun(0.3, 0.2, 0.1, size); + regressionForest.setCosSize(40); + List a = fun(0.1, 0.2, 0.3, size, 2, 1); + List b = fun(0.3, 0.2, 0.1, size, 2, 2); for (int i = 0; i < 1000; i++) { double[] featureA = a.get(i); double[] featureB = b.get(i); @@ -54,12 +55,14 @@ public class ForestTest { } - public static List fun(double w1, double w2, double w3, int size) {//生成假数据 + public static List fun(double w1, double w2, double w3, int size, int region, int index) {//生成假数据 List list = new ArrayList<>(); Random random = new Random(); + int nub = (index - 1) * 100; + double max = region * 100; for (int i = 0; i < size; i++) { + double b = (double) (random.nextInt(100) + nub) / max; double a = random.nextDouble(); - double b = random.nextDouble(); double c = w1 * a + w2 * b + w3; double[] data = new double[]{a, b, c}; list.add(data); From 9e72af8212366441180475670ec74b21a064acd0 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Thu, 17 Sep 2020 12:10:14 +0800 Subject: [PATCH 09/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 35 +++++-------------- .../regressionForest/RegressionForest.java | 5 +-- src/test/java/coverTest/ForestTest.java | 2 -- 3 files changed, 11 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 9127010..50a50f1 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -150,6 +150,10 @@ public class Forest extends Frequency { return data; } + public void pruning() {//进行剪枝 + + } + public void cut() throws Exception { int y = resultMatrix.getX(); if (y > 8) { @@ -159,7 +163,6 @@ public class Forest extends Frequency { median = dm[z]; //检测中位数median有多少个一样的值 int equalNub = getEqualNub(median, dm); - //////////// forestLeft = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc); Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 @@ -172,51 +175,29 @@ public class Forest extends Frequency { forestRight.setResultMatrix(resultMatrixRight); int leftIndex = 0;//左矩阵添加行数 int rightIndex = 0;//右矩阵添加行数 - double[] resultLeft = new double[z + equalNub]; - double[] resultRight = new double[y - z - equalNub]; - ////// for (int i = 0; i < y; i++) { double nub; if (isOldG) {//使用原有基 - if (oldGId == featureSize - 1) {//从结果矩阵提取数据 - nub = resultMatrix.getNumber(i, 0); - } else {//从条件矩阵中提取数据 - nub = conditionMatrix.getNumber(i, oldGId); - } + nub = matrixAll.getNumber(i, oldGId); } else {//使用新基 Matrix parameter = matrixAll.getRow(i); nub = transG(pc1, parameter, gNorm); } - //double nub = resultMatrix.getNumber(i, 0);//结果矩阵 if (nub > median) {//进入右森林并计算右森林结果矩阵方差 for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); } - resultRight[rightIndex] = nub; - resultMatrixRight.setNub(rightIndex, 0, nub); + resultMatrixRight.setNub(rightIndex, 0, resultMatrix.getNumber(i, 0)); rightIndex++; } else {//进入左森林并计算左森林结果矩阵方差 for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j)); } - resultLeft[leftIndex] = nub; - resultMatrixLeft.setNub(leftIndex, 0, nub); + resultMatrixLeft.setNub(leftIndex, 0, resultMatrix.getNumber(i, 0)); leftIndex++; } } - //分区完成,计算两棵树结果矩阵的方差 - double leftVar = variance(resultLeft); - double rightVar = variance(resultRight); - double variance = resultVariance * shrinkParameter; - System.out.println("var==" + variance + ",leftVar==" + leftVar + ",rightVar==" + rightVar); - if (leftVar > variance && rightVar > variance) {//不进行拆分,回退 - forestLeft = null; - forestRight = null; - median = 0; - } else { - forestLeft.setResultVariance(leftVar); - forestRight.setResultVariance(rightVar); - } + //分区完成 } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 00d9a6c..85850da 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -186,6 +186,8 @@ public class RegressionForest extends Frequency { min = limit[0]; max = limit[1]; start(forest); + //进行回归 + regression(); } else { throw new Exception("rootForest is null"); } @@ -199,10 +201,9 @@ public class RegressionForest extends Frequency { start(forestLeft); start(forestRight); } - } - public void regression() throws Exception {//开始进行回归 + private void regression() throws Exception {//开始进行回归 if (forest != null) { regressionTree(forest); } else { diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index 66cd738..abaaa0b 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -32,8 +32,6 @@ public class ForestTest { regressionForest.insertFeature(testB, featureB[2]); } regressionForest.startStudy(); - regressionForest.regression();//这里进行回归 - double sigma = 0; for (int i = 0; i < 1000; i++) { double[] feature = a.get(i); From 5d53a3415d8660fbd2e12ea55d8f6c287e7336bb Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Thu, 17 Sep 2020 15:42:39 +0800 Subject: [PATCH 10/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 24 ++++++++++++++++++- .../regressionForest/RegressionForest.java | 20 ++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 50a50f1..c4e1af8 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -150,7 +150,29 @@ public class Forest extends Frequency { return data; } - public void pruning() {//进行剪枝 + private double getDist(double[] data) { + int len = data.length; + double sigma = 0; + for (int i = 0; i < len; i++) { + double sub = data[i] - w[i]; + sigma = sigma + Math.pow(sub, 2); + } + return sigma / len; + } + + public void pruning() {//进行后剪枝 + if (forestLeft != null) { + double leftDist = getDist(forestLeft.getW()); + if (leftDist < shrinkParameter) {//剪枝 + forestLeft = null; + } + } + if (forestRight != null) { + double rightDist = getDist(forestRight.getW()); + if (rightDist < shrinkParameter) { + forestRight = null; + } + } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 85850da..306c770 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -188,6 +188,8 @@ public class RegressionForest extends Frequency { start(forest); //进行回归 regression(); + //进行剪枝 + pruning(); } else { throw new Exception("rootForest is null"); } @@ -203,6 +205,24 @@ public class RegressionForest extends Frequency { } } + private void pruning() throws Exception { + if (forest != null) { + pruningTree(forest); + } else { + throw new Exception("rootForest is null"); + } + } + + private void pruningTree(Forest forest) { + if (forest != null) { + forest.pruning(); + Forest forestRight = forest.getForestRight(); + pruningTree(forestRight); + Forest forestLeft = forest.getForestLeft(); + pruningTree(forestLeft); + } + } + private void regression() throws Exception {//开始进行回归 if (forest != null) { regressionTree(forest); From c13848884b1c886ae0a489096c04a38389329b46 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Thu, 17 Sep 2020 17:30:34 +0800 Subject: [PATCH 11/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/regressionForest/RegressionForest.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 306c770..1b838eb 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -250,7 +250,6 @@ public class RegressionForest extends Frequency { for (int i = 0; i < ws.getX(); i++) { w[i] = ws.getNumber(i, 0); } - System.out.println(Arrays.toString(w)); - System.out.println("=========================="); + } } \ No newline at end of file From 7a4e006cec64bdfd11d982a76d90246280ee88da Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Fri, 18 Sep 2020 13:17:08 +0800 Subject: [PATCH 12/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 35 +++++++++---------- .../regressionForest/RegressionForest.java | 14 ++++---- src/test/java/coverTest/ForestTest.java | 2 +- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c4e1af8..da4ebb1 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -49,17 +49,6 @@ public class Forest extends Frequency { this.resultVariance = resultVariance; } - //检测中位数median有多少个一样的值 - private int getEqualNub(double median, double[] dm) { - int equalNub = 0; - for (int i = 0; i < dm.length; i++) { - if (median == dm[i]) { - equalNub++; - } - } - return equalNub; - } - private double[] findG() throws Exception {//寻找新的切入维度 // 先尝试从原有维度切入 int xSize = conditionMatrix.getX(); @@ -163,12 +152,14 @@ public class Forest extends Frequency { public void pruning() {//进行后剪枝 if (forestLeft != null) { double leftDist = getDist(forestLeft.getW()); + System.out.println("左剪枝阈值:" + leftDist); if (leftDist < shrinkParameter) {//剪枝 forestLeft = null; } } if (forestRight != null) { double rightDist = getDist(forestRight.getW()); + System.out.println("右剪枝阈值:" + rightDist); if (rightDist < shrinkParameter) { forestRight = null; } @@ -178,19 +169,25 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 8) { + if (y > 200) { double[] dm = findG(); - Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; - //检测中位数median有多少个一样的值 - int equalNub = getEqualNub(median, dm); + int rightNub = 0; + int leftNub = 0; + for (int i = 0; i < dm.length; i++) { + if (dm[i] > median) { + rightNub++; + } else { + leftNub++; + } + } forestLeft = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc); - Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 - Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 - Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 - Matrix resultMatrixRight = new Matrix(y - z - equalNub, 1);//结果矩阵右 + Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右 + Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左 + Matrix resultMatrixRight = new Matrix(rightNub, 1);//结果矩阵右 forestLeft.setConditionMatrix(conditionMatrixLeft); forestLeft.setResultMatrix(resultMatrixLeft); forestRight.setConditionMatrix(conditionMatrixRight); diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 1b838eb..26ada20 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -79,16 +79,14 @@ public class RegressionForest extends Frequency { private Forest getRegion(Forest forest, double result) { double median = forest.getMedian(); - if (median > 0) {//进行了拆分 - if (result > median) {//向右走 - forest = forest.getForestRight(); - } else {//向左走 - forest = forest.getForestLeft(); - } - return getRegion(forest, result); - } else {//没有拆分 + if (result > median && forest.getForestRight() != null) {//向右走 + forest = forest.getForestRight(); + } else if (result <= median && forest.getForestLeft() != null) {//向左走 + forest = forest.getForestLeft(); + } else { return forest; } + return getRegion(forest, result); } private Forest getLimitRegion(Forest forest, boolean isMax) { diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index abaaa0b..f85a0ac 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -19,7 +19,7 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.02); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); List b = fun(0.3, 0.2, 0.1, size, 2, 2); From c7b00adf0ce2449368248a7be50dee7a95936b49 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Fri, 18 Sep 2020 14:02:55 +0800 Subject: [PATCH 13/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/org/wlld/regressionForest/Forest.java | 8 +++++--- .../wlld/regressionForest/RegressionForest.java | 2 +- src/test/java/coverTest/ForestTest.java | 15 +++++++++------ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index da4ebb1..d905124 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -97,7 +97,7 @@ public class Forest extends Frequency { pc1 = g; } } - //找到非原始基最离散的新基 + //找到非原始基最离散的新基: if (max > maxOld) {//使用新基 isOldG = false; } else {//使用原有基 @@ -150,17 +150,18 @@ public class Forest extends Frequency { } public void pruning() {//进行后剪枝 + System.out.println("执行了剪枝===="); if (forestLeft != null) { double leftDist = getDist(forestLeft.getW()); - System.out.println("左剪枝阈值:" + leftDist); if (leftDist < shrinkParameter) {//剪枝 + System.out.println("成功左剪枝阈值:" + leftDist + ",阈值:" + shrinkParameter); forestLeft = null; } } if (forestRight != null) { double rightDist = getDist(forestRight.getW()); - System.out.println("右剪枝阈值:" + rightDist); if (rightDist < shrinkParameter) { + System.out.println("成功右剪枝阈值:" + rightDist + ",阈值:" + shrinkParameter); forestRight = null; } } @@ -170,6 +171,7 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); if (y > 200) { + System.out.println("-======================"); double[] dm = findG(); int z = y / 2; median = dm[z]; diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 26ada20..f334ce1 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -215,8 +215,8 @@ public class RegressionForest extends Frequency { if (forest != null) { forest.pruning(); Forest forestRight = forest.getForestRight(); - pruningTree(forestRight); Forest forestLeft = forest.getForestLeft(); + pruningTree(forestRight); pruningTree(forestLeft); } } diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index f85a0ac..188ede0 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -19,10 +19,10 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.02); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.01); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); - List b = fun(0.3, 0.2, 0.1, size, 2, 2); + List b = fun(0.7, 0.3, 0.1, size, 2, 2); for (int i = 0; i < 1000; i++) { double[] featureA = a.get(i); double[] featureB = b.get(i); @@ -32,21 +32,24 @@ public class ForestTest { regressionForest.insertFeature(testB, featureB[2]); } regressionForest.startStudy(); + /// + List a1 = fun(0.1, 0.2, 0.3, size, 2, 1); + List b1 = fun(0.3, 0.2, 0.6, size, 2, 2); double sigma = 0; for (int i = 0; i < 1000; i++) { - double[] feature = a.get(i); + double[] feature = a1.get(i); double[] test = new double[]{feature[0], feature[1]}; double dist = regressionForest.getDist(test, feature[2]); - sigma = sigma + Math.pow(dist, 2); + sigma = sigma + dist; } double avs = sigma / size; System.out.println("a误差:" + avs); sigma = 0; for (int i = 0; i < 1000; i++) { - double[] feature = b.get(i); + double[] feature = b1.get(i); double[] test = new double[]{feature[0], feature[1]}; double dist = regressionForest.getDist(test, feature[2]); - sigma = sigma + Math.pow(dist, 2); + sigma = sigma + dist; } double avs2 = sigma / size; System.out.println("b误差:" + avs2); From 91f8e72c5be732ec9574519d7f821107b0d32f5c Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 21 Sep 2020 17:17:44 +0800 Subject: [PATCH 14/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 75 +++++++++++++++---- .../regressionForest/RegressionForest.java | 32 ++++---- src/test/java/coverTest/ForestTest.java | 18 +++-- 3 files changed, 91 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index d905124..db17fe9 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -29,12 +29,20 @@ public class Forest extends Frequency { private int oldGId = 0;//老基的id private Matrix matrixAll;//全矩阵 private double gNorm;//新维度的摸 + private Forest father;//父级 + private Map forestMap;//尽头列表 + private int id;//本节点的id + private boolean isRemove = false;//是否已经被移除了 + private boolean notRemovable = false;//不可移除 - public Forest(int featureSize, double shrinkParameter, Matrix pc) { + public Forest(int featureSize, double shrinkParameter, Matrix pc, Map forestMap + , int id) { this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; this.pc = pc; w = new double[featureSize]; + this.forestMap = forestMap; + this.id = id; } public double getMedian() { @@ -139,7 +147,7 @@ public class Forest extends Frequency { return data; } - private double getDist(double[] data) { + private double getDist(double[] data, double[] w) { int len = data.length; double sigma = 0; for (int i = 0; i < len; i++) { @@ -149,23 +157,30 @@ public class Forest extends Frequency { return sigma / len; } - public void pruning() {//进行后剪枝 + public void pruning() {//进行后剪枝,跟父级进行比较 System.out.println("执行了剪枝===="); - if (forestLeft != null) { - double leftDist = getDist(forestLeft.getW()); - if (leftDist < shrinkParameter) {//剪枝 - System.out.println("成功左剪枝阈值:" + leftDist + ",阈值:" + shrinkParameter); - forestLeft = null; + if (!notRemovable) { + Forest fatherForest = this.getFather(); + double[] fatherW = fatherForest.getW(); + double sub = getDist(w, fatherW); + if (sub < shrinkParameter) {//需要剪枝,通知父级 + + } else {//通知父级,不需要剪枝,并将父级改为不可移除 + } } - if (forestRight != null) { - double rightDist = getDist(forestRight.getW()); - if (rightDist < shrinkParameter) { - System.out.println("成功右剪枝阈值:" + rightDist + ",阈值:" + shrinkParameter); + } + + public void getSonMessage(boolean isPruning, int myId) {//进行剪枝 + if (isPruning) {//剪枝 + if (myId == id * 2) {//左节点 + forestLeft = null; + } else {//右节点 forestRight = null; } + } else {//不剪枝,将自己变为不可剪枝状态 + notRemovable = true; } - } public void cut() throws Exception { @@ -184,8 +199,14 @@ public class Forest extends Frequency { leftNub++; } } - forestLeft = new Forest(featureSize, shrinkParameter, pc); - forestRight = new Forest(featureSize, shrinkParameter, pc); + int leftId = 2 * id; + int rightId = leftId + 1; + forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId); + forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId); + forestMap.put(leftId, forestLeft); + forestMap.put(rightId, forestRight); + forestRight.setFather(this); + forestLeft.setFather(this); Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右 Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左 @@ -253,4 +274,28 @@ public class Forest extends Frequency { public Forest getForestRight() { return forestRight; } + + public Forest getFather() { + return father; + } + + public void setFather(Forest father) { + this.father = father; + } + + public boolean isRemove() { + return isRemove; + } + + public void setRemove(boolean remove) { + isRemove = remove; + } + + public boolean isNotRemovable() { + return notRemovable; + } + + public void setNotRemovable(boolean notRemovable) { + this.notRemovable = notRemovable; + } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index f334ce1..9ee4aff 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -4,10 +4,7 @@ import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; +import java.util.*; /** * @param @@ -27,6 +24,7 @@ public class RegressionForest extends Frequency { private double max;//结果最大值 private Matrix pc;//需要映射的基 private int cosSize = 20;//cos 分成几份 + private TreeMap forestMap = new TreeMap<>();//节点列表 public int getCosSize() { return cosSize; @@ -44,7 +42,8 @@ public class RegressionForest extends Frequency { conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); createG(); - forest = new Forest(featureNub, shrinkParameter, pc); + forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1); + forestMap.put(1, forest); forest.setW(w); forest.setConditionMatrix(conditionMatrix); forest.setResultMatrix(resultMatrix); @@ -203,21 +202,26 @@ public class RegressionForest extends Frequency { } } - private void pruning() throws Exception { - if (forest != null) { - pruningTree(forest); - } else { - throw new Exception("rootForest is null"); + private void pruning() throws Exception {//剪枝 + //先获取当前最大id + int max = forestMap.lastKey(); + int layersNub = (int) (Math.log(max) / Math.log(2)) + 1;//当前的层数 + int lastMin = (int) Math.pow(2, layersNub - 1);//最后一层最小的id + if (layersNub > 1) {//先遍历最后一层 + for (Map.Entry entry : forestMap.entrySet()) { + if (entry.getKey() >= lastMin) { + Forest forest = entry.getValue(); + + } + } } } private void pruningTree(Forest forest) { if (forest != null) { forest.pruning(); - Forest forestRight = forest.getForestRight(); - Forest forestLeft = forest.getForestLeft(); - pruningTree(forestRight); - pruningTree(forestLeft); + Forest father = forest.getFather(); + pruningTree(father); } } diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index 188ede0..c165adb 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -1,10 +1,9 @@ package coverTest; +import org.wlld.randomForest.Tree; import org.wlld.regressionForest.RegressionForest; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; +import java.util.*; /** * @param @@ -14,7 +13,16 @@ import java.util.Random; */ public class ForestTest { public static void main(String[] args) throws Exception { - test(); + //test(); + //int a = (int) (Math.log(4) / Math.log(2));//id22是第几层 + //double a = Math.pow(2, 5) - 1; 第五层的第一个数 + // System.out.println("a==" + a); + TreeMap map = new TreeMap<>(); + map.put(5, "a"); + map.put(3, "b"); + map.put(4, "c"); + map.put(6, "d"); + map.put(7, "e"); } public static void test() throws Exception {//对分段回归进行测试 @@ -34,7 +42,7 @@ public class ForestTest { regressionForest.startStudy(); /// List a1 = fun(0.1, 0.2, 0.3, size, 2, 1); - List b1 = fun(0.3, 0.2, 0.6, size, 2, 2); + List b1 = fun(0.7, 0.3, 0.1, size, 2, 2); double sigma = 0; for (int i = 0; i < 1000; i++) { double[] feature = a1.get(i); From 2fdc4ba80d9ae93f2f528094ea4eaf9fc7854584 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Tue, 22 Sep 2020 11:58:10 +0800 Subject: [PATCH 15/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5?= =?UTF-8?q?=E5=9B=9E=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 6 +-- .../regressionForest/RegressionForest.java | 37 ++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index db17fe9..de1920e 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -158,15 +158,15 @@ public class Forest extends Frequency { } public void pruning() {//进行后剪枝,跟父级进行比较 - System.out.println("执行了剪枝===="); if (!notRemovable) { Forest fatherForest = this.getFather(); double[] fatherW = fatherForest.getW(); double sub = getDist(w, fatherW); if (sub < shrinkParameter) {//需要剪枝,通知父级 - + fatherForest.getSonMessage(true, id); + isRemove = true; } else {//通知父级,不需要剪枝,并将父级改为不可移除 - + fatherForest.getSonMessage(false, id); } } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 9ee4aff..43a0965 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -205,23 +205,40 @@ public class RegressionForest extends Frequency { private void pruning() throws Exception {//剪枝 //先获取当前最大id int max = forestMap.lastKey(); - int layersNub = (int) (Math.log(max) / Math.log(2)) + 1;//当前的层数 - int lastMin = (int) Math.pow(2, layersNub - 1);//最后一层最小的id + int layersNub = (int) (Math.log(max) / Math.log(2));//当前的层数 + int lastMin = (int) Math.pow(2, layersNub);//最后一层最小的id if (layersNub > 1) {//先遍历最后一层 for (Map.Entry entry : forestMap.entrySet()) { if (entry.getKey() >= lastMin) { Forest forest = entry.getValue(); - + forest.pruning(); } } } - } - - private void pruningTree(Forest forest) { - if (forest != null) { - forest.pruning(); - Forest father = forest.getFather(); - pruningTree(father); + //每一层从下到上进行剪枝 + for (int i = layersNub - 1; i > 0; i++) { + int min = (int) Math.pow(2, i);//最后一层最小的id + int maxNub = (int) Math.pow(2, i + 1); + for (Map.Entry entry : forestMap.entrySet()) { + int key = entry.getKey(); + if (key >= min && key < maxNub) {//在范围内,进行剪枝 + entry.getValue().pruning(); + } else if (key >= maxNub) { + break; + } + } + } + //遍历所有节点,将删除的节点移除 + List list = new ArrayList<>(); + for (Map.Entry entry : forestMap.entrySet()) { + int key = entry.getKey(); + Forest forest = entry.getValue(); + if (forest.isRemove()) { + list.add(key); + } + } + for (int key : list) { + forestMap.remove(key); } } From dfe0b9576dcd2c20d8b0b7f3521d9c6aefc42189 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Tue, 22 Sep 2020 13:50:32 +0800 Subject: [PATCH 16/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=8C=E5=8F=89?= =?UTF-8?q?=E6=A0=91=E5=9B=9E=E5=BD=92=E5=8F=8A=E5=90=8E=E5=89=AA=E6=9E=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/regressionForest/Forest.java | 6 +++--- .../org/wlld/regressionForest/RegressionForest.java | 2 +- src/test/java/coverTest/ForestTest.java | 11 ++++------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index de1920e..4158865 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -165,6 +165,7 @@ public class Forest extends Frequency { if (sub < shrinkParameter) {//需要剪枝,通知父级 fatherForest.getSonMessage(true, id); isRemove = true; + //System.out.println("剪枝id==" + id + ",sub==" + sub + ",th==" + shrinkParameter); } else {//通知父级,不需要剪枝,并将父级改为不可移除 fatherForest.getSonMessage(false, id); } @@ -186,7 +187,6 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); if (y > 200) { - System.out.println("-======================"); double[] dm = findG(); int z = y / 2; median = dm[z]; @@ -201,10 +201,10 @@ public class Forest extends Frequency { } int leftId = 2 * id; int rightId = leftId + 1; + //System.out.println("id:" + id + ",size:" + dm.length); + forestMap.put(id, this); forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId); forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId); - forestMap.put(leftId, forestLeft); - forestMap.put(rightId, forestRight); forestRight.setFather(this); forestLeft.setFather(this); Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 43a0965..5d1ab13 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -216,7 +216,7 @@ public class RegressionForest extends Frequency { } } //每一层从下到上进行剪枝 - for (int i = layersNub - 1; i > 0; i++) { + for (int i = layersNub - 1; i > 0; i--) { int min = (int) Math.pow(2, i);//最后一层最小的id int maxNub = (int) Math.pow(2, i + 1); for (Map.Entry entry : forestMap.entrySet()) { diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index c165adb..adcc746 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -13,16 +13,11 @@ import java.util.*; */ public class ForestTest { public static void main(String[] args) throws Exception { - //test(); + test(); //int a = (int) (Math.log(4) / Math.log(2));//id22是第几层 //double a = Math.pow(2, 5) - 1; 第五层的第一个数 // System.out.println("a==" + a); - TreeMap map = new TreeMap<>(); - map.put(5, "a"); - map.put(3, "b"); - map.put(4, "c"); - map.put(6, "d"); - map.put(7, "e"); + } public static void test() throws Exception {//对分段回归进行测试 @@ -52,6 +47,8 @@ public class ForestTest { } double avs = sigma / size; System.out.println("a误差:" + avs); +// a误差:0.0017585065712555645 +// b误差:0.00761733737464547 sigma = 0; for (int i = 0; i < 1000; i++) { double[] feature = b1.get(i); From 58db5df18cb5d2c211c47570d02001cfa75f93cf Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Tue, 22 Sep 2020 14:21:35 +0800 Subject: [PATCH 17/17] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=8C=E5=8F=89?= =?UTF-8?q?=E6=A0=91=E5=9B=9E=E5=BD=92=E5=8F=8A=E5=90=8E=E5=89=AA=E6=9E=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/regressionForest/Forest.java | 10 ++++++---- .../org/wlld/regressionForest/RegressionForest.java | 4 ++-- src/test/java/coverTest/ForestTest.java | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 4158865..f3f12b4 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -34,15 +34,17 @@ public class Forest extends Frequency { private int id;//本节点的id private boolean isRemove = false;//是否已经被移除了 private boolean notRemovable = false;//不可移除 + private int minGrain;//最小粒度 public Forest(int featureSize, double shrinkParameter, Matrix pc, Map forestMap - , int id) { + , int id, int minGrain) { this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; this.pc = pc; w = new double[featureSize]; this.forestMap = forestMap; this.id = id; + this.minGrain = minGrain; } public double getMedian() { @@ -186,7 +188,7 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 200) { + if (y > minGrain) { double[] dm = findG(); int z = y / 2; median = dm[z]; @@ -203,8 +205,8 @@ public class Forest extends Frequency { int rightId = leftId + 1; //System.out.println("id:" + id + ",size:" + dm.length); forestMap.put(id, this); - forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId); - forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId); + forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId, minGrain); + forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId, minGrain); forestRight.setFather(this); forestLeft.setFather(this); Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 5d1ab13..95b8ba3 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -34,7 +34,7 @@ public class RegressionForest extends Frequency { this.cosSize = cosSize; } - public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 + public RegressionForest(int size, int featureNub, double shrinkParameter, int minGrain) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; w = new double[featureNub]; @@ -42,7 +42,7 @@ public class RegressionForest extends Frequency { conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); createG(); - forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1); + forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1, minGrain); forestMap.put(1, forest); forest.setW(w); forest.setConditionMatrix(conditionMatrix); diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index adcc746..e3dacd5 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -22,7 +22,7 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.01); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); List b = fun(0.7, 0.3, 0.1, size, 2, 2);