diff --git a/src/main/java/org/wlld/randomForest/RandomForest.java b/src/main/java/org/wlld/randomForest/RandomForest.java index 28e9257..59319f6 100644 --- a/src/main/java/org/wlld/randomForest/RandomForest.java +++ b/src/main/java/org/wlld/randomForest/RandomForest.java @@ -12,6 +12,11 @@ import java.util.*; public class RandomForest { private Random random = new Random(); private Tree[] forest; + private double trustTh = 0.1;//信任阈值 + + public void setTrustTh(double trustTh) {//设置信任阈值 + this.trustTh = trustTh; + } public RandomForest(int treeNub) throws Exception { if (treeNub > 0) { @@ -67,6 +72,9 @@ public class RandomForest { nub = myNub; } } + if (nub < ArithUtil.mul(forest.length, trustTh)) { + type = 0; + } return type; } @@ -75,7 +83,7 @@ public class RandomForest { if (dataTable.getSize() > 4) { int kNub = (int) ArithUtil.div(Math.log(dataTable.getSize()), Math.log(2)); //int kNub = dataTable.getSize() - 1; - // System.out.println("knNub==" + kNub); + // System.out.println("knNub==" + kNub); for (int i = 0; i < forest.length; i++) { Tree tree = new Tree(getRandomData(dataTable, kNub)); forest[i] = tree;