From 7a4e006cec64bdfd11d982a76d90246280ee88da Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Fri, 18 Sep 2020 13:17:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5=E5=9B=9E?= =?UTF-8?q?=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 35 +++++++++---------- .../regressionForest/RegressionForest.java | 14 ++++---- src/test/java/coverTest/ForestTest.java | 2 +- 3 files changed, 23 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c4e1af8..da4ebb1 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -49,17 +49,6 @@ public class Forest extends Frequency { this.resultVariance = resultVariance; } - //检测中位数median有多少个一样的值 - private int getEqualNub(double median, double[] dm) { - int equalNub = 0; - for (int i = 0; i < dm.length; i++) { - if (median == dm[i]) { - equalNub++; - } - } - return equalNub; - } - private double[] findG() throws Exception {//寻找新的切入维度 // 先尝试从原有维度切入 int xSize = conditionMatrix.getX(); @@ -163,12 +152,14 @@ public class Forest extends Frequency { public void pruning() {//进行后剪枝 if (forestLeft != null) { double leftDist = getDist(forestLeft.getW()); + System.out.println("左剪枝阈值:" + leftDist); if (leftDist < shrinkParameter) {//剪枝 forestLeft = null; } } if (forestRight != null) { double rightDist = getDist(forestRight.getW()); + System.out.println("右剪枝阈值:" + rightDist); if (rightDist < shrinkParameter) { forestRight = null; } @@ -178,19 +169,25 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 8) { + if (y > 200) { double[] dm = findG(); - Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; - //检测中位数median有多少个一样的值 - int equalNub = getEqualNub(median, dm); + int rightNub = 0; + int leftNub = 0; + for (int i = 0; i < dm.length; i++) { + if (dm[i] > median) { + rightNub++; + } else { + leftNub++; + } + } forestLeft = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc); - Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 - Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 - Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 - Matrix resultMatrixRight = new Matrix(y - z - equalNub, 1);//结果矩阵右 + Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右 + Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左 + Matrix resultMatrixRight = new Matrix(rightNub, 1);//结果矩阵右 forestLeft.setConditionMatrix(conditionMatrixLeft); forestLeft.setResultMatrix(resultMatrixLeft); forestRight.setConditionMatrix(conditionMatrixRight); diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 1b838eb..26ada20 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -79,16 +79,14 @@ public class RegressionForest extends Frequency { private Forest getRegion(Forest forest, double result) { double median = forest.getMedian(); - if (median > 0) {//进行了拆分 - if (result > median) {//向右走 - forest = forest.getForestRight(); - } else {//向左走 - forest = forest.getForestLeft(); - } - return getRegion(forest, result); - } else {//没有拆分 + if (result > median && forest.getForestRight() != null) {//向右走 + forest = forest.getForestRight(); + } else if (result <= median && forest.getForestLeft() != null) {//向左走 + forest = forest.getForestLeft(); + } else { return forest; } + return getRegion(forest, result); } private Forest getLimitRegion(Forest forest, boolean isMax) { diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index abaaa0b..f85a0ac 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -19,7 +19,7 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.02); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); List b = fun(0.3, 0.2, 0.1, size, 2, 2);