diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 50a50f1..c4e1af8 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -150,7 +150,29 @@ public class Forest extends Frequency { return data; } - public void pruning() {//进行剪枝 + private double getDist(double[] data) { + int len = data.length; + double sigma = 0; + for (int i = 0; i < len; i++) { + double sub = data[i] - w[i]; + sigma = sigma + Math.pow(sub, 2); + } + return sigma / len; + } + + public void pruning() {//进行后剪枝 + if (forestLeft != null) { + double leftDist = getDist(forestLeft.getW()); + if (leftDist < shrinkParameter) {//剪枝 + forestLeft = null; + } + } + if (forestRight != null) { + double rightDist = getDist(forestRight.getW()); + if (rightDist < shrinkParameter) { + forestRight = null; + } + } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 85850da..306c770 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -188,6 +188,8 @@ public class RegressionForest extends Frequency { start(forest); //进行回归 regression(); + //进行剪枝 + pruning(); } else { throw new Exception("rootForest is null"); } @@ -203,6 +205,24 @@ public class RegressionForest extends Frequency { } } + private void pruning() throws Exception { + if (forest != null) { + pruningTree(forest); + } else { + throw new Exception("rootForest is null"); + } + } + + private void pruningTree(Forest forest) { + if (forest != null) { + forest.pruning(); + Forest forestRight = forest.getForestRight(); + pruningTree(forestRight); + Forest forestLeft = forest.getForestLeft(); + pruningTree(forestLeft); + } + } + private void regression() throws Exception {//开始进行回归 if (forest != null) { regressionTree(forest);