diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index fc83861..1d1f5f0 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -13,7 +13,7 @@ import java.util.*; public class Tree {//决策树 private DataTable dataTable; private Map> table;//总样本 - private Node rootNode = new Node();//根节点 + private Node rootNode;//根节点 private List endList;//最终结果分类 private List lastNodes = new ArrayList<>();//最后一层节点集合 private Random random = new Random(); @@ -203,7 +203,11 @@ public class Tree {//决策树 } public int judge(Object ob) throws Exception {//进行类别判断 - return goTree(ob, rootNode); + if (rootNode != null) { + return goTree(ob, rootNode); + } else { + throw new Exception("rootNode is null"); + } } private int goTree(Object ob, Node node) throws Exception {//从树顶向下攀爬 @@ -230,6 +234,7 @@ public class Tree {//决策树 public void study() throws Exception { if (dataTable != null && dataTable.getLength() > 0) { + rootNode = new Node(); table = dataTable.getTable(); endList = dataTable.getTable().get(dataTable.getKey()); Set set = dataTable.getKeyType();