diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 4158865..f3f12b4 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -34,15 +34,17 @@ public class Forest extends Frequency { private int id;//本节点的id private boolean isRemove = false;//是否已经被移除了 private boolean notRemovable = false;//不可移除 + private int minGrain;//最小粒度 public Forest(int featureSize, double shrinkParameter, Matrix pc, Map forestMap - , int id) { + , int id, int minGrain) { this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; this.pc = pc; w = new double[featureSize]; this.forestMap = forestMap; this.id = id; + this.minGrain = minGrain; } public double getMedian() { @@ -186,7 +188,7 @@ public class Forest extends Frequency { public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 200) { + if (y > minGrain) { double[] dm = findG(); int z = y / 2; median = dm[z]; @@ -203,8 +205,8 @@ public class Forest extends Frequency { int rightId = leftId + 1; //System.out.println("id:" + id + ",size:" + dm.length); forestMap.put(id, this); - forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId); - forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId); + forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId, minGrain); + forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId, minGrain); forestRight.setFather(this); forestLeft.setFather(this); Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 5d1ab13..95b8ba3 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -34,7 +34,7 @@ public class RegressionForest extends Frequency { this.cosSize = cosSize; } - public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 + public RegressionForest(int size, int featureNub, double shrinkParameter, int minGrain) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; w = new double[featureNub]; @@ -42,7 +42,7 @@ public class RegressionForest extends Frequency { conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); createG(); - forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1); + forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1, minGrain); forestMap.put(1, forest); forest.setW(w); forest.setConditionMatrix(conditionMatrix); diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index adcc746..e3dacd5 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -22,7 +22,7 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; - RegressionForest regressionForest = new RegressionForest(size, 3, 0.01); + RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200); regressionForest.setCosSize(40); List a = fun(0.1, 0.2, 0.3, size, 2, 1); List b = fun(0.7, 0.3, 0.1, size, 2, 2);