diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index da4ebb1..d905124 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -97,7 +97,7 @@ public class Forest extends Frequency { pc1 = g; } } - //找到非原始基最离散的新基 + //找到非原始基最离散的新基: if (max > maxOld) {//使用新基 isOldG = false; } else {//使用原有基 @@ -150,17 +150,18 @@ public class Forest extends Frequency { } public void pruning() {//进行后剪枝 + System.out.println("执行了剪枝===="); if (forestLeft != null) { double leftDist = getDist(forestLeft.getW()); - System.out.println("左剪枝阈值:" + leftDist); if (leftDist < shrinkParameter) {//剪枝 + System.out.println("成功左剪枝阈值:" + leftDist + ",阈值:" + shrinkParameter); forestLeft = null; } } if (forestRight != null) { double rightDist = getDist(forestRight.getW()); - System.out.println("右剪枝阈值:" + rightDist); if (rightDist < shrinkParameter) { + System.out.println("成功右剪枝阈值:" + rightDist + ",阈值:" + shrinkParameter); forestRight = null; } } @@ -170,6 +171,7 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); if (y > 200) { + System.out.println("-======================"); double[] dm = findG(); int z = y / 2; median = dm[z]; diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 26ada20..f334ce1 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -215,8 +215,8 @@ public class RegressionForest extends Frequency { if (forest != null) { forest.pruning(); Forest forestRight = forest.getForestRight(); - pruningTree(forestRight); Forest forestLeft = forest.getForestLeft(); + pruningTree(forestRight); pruningTree(forestLeft); } } diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index f85a0ac..188ede0 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -19,10 +19,10 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.02); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.01); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); - List b = fun(0.3, 0.2, 0.1, size, 2, 2); + List b = fun(0.7, 0.3, 0.1, size, 2, 2); for (int i = 0; i < 1000; i++) { double[] featureA = a.get(i); double[] featureB = b.get(i); @@ -32,21 +32,24 @@ public class ForestTest { regressionForest.insertFeature(testB, featureB[2]); } regressionForest.startStudy(); + /// + List a1 = fun(0.1, 0.2, 0.3, size, 2, 1); + List b1 = fun(0.3, 0.2, 0.6, size, 2, 2); double sigma = 0; for (int i = 0; i < 1000; i++) { - double[] feature = a.get(i); + double[] feature = a1.get(i); double[] test = new double[]{feature[0], feature[1]}; double dist = regressionForest.getDist(test, feature[2]); - sigma = sigma + Math.pow(dist, 2); + sigma = sigma + dist; } double avs = sigma / size; System.out.println("a误差:" + avs); sigma = 0; for (int i = 0; i < 1000; i++) { - double[] feature = b.get(i); + double[] feature = b1.get(i); double[] test = new double[]{feature[0], feature[1]}; double dist = regressionForest.getDist(test, feature[2]); - sigma = sigma + Math.pow(dist, 2); + sigma = sigma + dist; } double avs2 = sigma / size; System.out.println("b误差:" + avs2);