From 8abe63b1fcb378f49eef8378ab07ea603eb9937d Mon Sep 17 00:00:00 2001 From: lidapeng Date: Sat, 22 Feb 2020 17:54:40 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=80=E4=B8=AA=E8=B7=9F?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E4=B8=BA=E7=A9=BA=E7=9A=84=E6=8A=A5=E9=94=99?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/randomForest/Tree.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index fc83861..1d1f5f0 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -13,7 +13,7 @@ import java.util.*; public class Tree {//决策树 private DataTable dataTable; private Map> table;//总样本 - private Node rootNode = new Node();//根节点 + private Node rootNode;//根节点 private List endList;//最终结果分类 private List lastNodes = new ArrayList<>();//最后一层节点集合 private Random random = new Random(); @@ -203,7 +203,11 @@ public class Tree {//决策树 } public int judge(Object ob) throws Exception {//进行类别判断 - return goTree(ob, rootNode); + if (rootNode != null) { + return goTree(ob, rootNode); + } else { + throw new Exception("rootNode is null"); + } } private int goTree(Object ob, Node node) throws Exception {//从树顶向下攀爬 @@ -230,6 +234,7 @@ public class Tree {//决策树 public void study() throws Exception { if (dataTable != null && dataTable.getLength() > 0) { + rootNode = new Node(); table = dataTable.getTable(); endList = dataTable.getTable().get(dataTable.getKey()); Set set = dataTable.getKeyType();