diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index d0b1a62..c62da7f 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -12,7 +12,7 @@ import java.util.*; public class Tree {//决策树 private DataTable dataTable; private Map> table;//总样本 - private Node rootNode;//根节点 + private Node rootNode = new Node();//根节点 private List endList;//最终结果分类 private class Node { @@ -29,7 +29,7 @@ public class Tree {//决策树 private double gainRatio; } - Tree(DataTable dataTable) throws Exception { + public Tree(DataTable dataTable) throws Exception { if (dataTable.getKey() != null && dataTable.getLength() > 0) { table = dataTable.getTable(); this.dataTable = dataTable; @@ -105,12 +105,11 @@ public class Tree {//决策树 Set nowAttribute = removeAttribute(attributes, name); Node sonNode = new Node(); nodeList.add(sonNode); - sonNode.key = mapEntry.getKey(); sonNode.attribute = nowAttribute; List list = entry.getValue(); sonNode.fatherList = list; int myNub = list.size(); - double ent = getEnt(list);//每一个信息熵都是一个子集 + double ent = getEnt(list); double dNub = ArithUtil.div(myNub, fatherNub); IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV); gain = getGain(ent, dNub, gain); @@ -118,20 +117,28 @@ public class Tree {//决策树 Gain gain1 = new Gain(); gainMap.put(name, gain1); gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益 - gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率 + if (IV != 0) { + gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率 + } else { + gain1.gainRatio = -1; + } sigmaG = ArithUtil.add(gain1.gain, sigmaG); i++; } - double avgGain = ArithUtil.div(sigmaG, i); + double avgGain = sigmaG / i; double gainRatio = 0;//最大增益率 String key = null;//可选属性 for (Map.Entry entry : gainMap.entrySet()) { Gain gain = entry.getValue(); - if (gain.gain > avgGain && gain.gainRatio > gainRatio) { + if (gainRatio == -1) { + break; + } + if (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gain.gainRatio == -1)) { gainRatio = gain.gainRatio; key = entry.getKey(); } } + node.key = key; List nodeList = nodeMap.get(key); for (int j = 0; j < nodeList.size(); j++) { Node node1 = nodeList.get(j); @@ -181,14 +188,16 @@ public class Tree {//决策树 public void study() throws Exception { if (dataTable != null) { endList = dataTable.getTable().get(dataTable.getKey()); - Node node = new Node(); - node.attribute = dataTable.getKeyType();//当前可用属性 + Set set = dataTable.getKeyType(); + set.remove(dataTable.getKey()); + rootNode.attribute = set;//当前可用属性 List list = new ArrayList<>(); for (int i = 0; i < endList.size(); i++) { list.add(i); } - node.fatherList = list;//当前父级样本 - createNode(node); + rootNode.fatherList = list;//当前父级样本 + List nodeList = createNode(rootNode); + rootNode.nodeList = nodeList; } else { throw new Exception("dataTable is null"); } diff --git a/src/test/java/org/wlld/Food.java b/src/test/java/org/wlld/Food.java index 78b817c..bdd0e4f 100644 --- a/src/test/java/org/wlld/Food.java +++ b/src/test/java/org/wlld/Food.java @@ -6,22 +6,31 @@ package org.wlld; * @date 8:11 上午 2020/2/18 */ public class Food { - private int foodId; - private double testId; + private int height;//身高 + private int weight;//体重 + private int sex;//性别 1男 2女 - public int getFoodId() { - return foodId; + public int getHeight() { + return height; } - public void setFoodId(int foodId) { - this.foodId = foodId; + public void setHeight(int height) { + this.height = height; } - public double getTestId() { - return testId; + public int getWeight() { + return weight; } - public void setTestId(double testId) { - this.testId = testId; + public void setWeight(int weight) { + this.weight = weight; + } + + public int getSex() { + return sex; + } + + public void setSex(int sex) { + this.sex = sex; } } diff --git a/src/test/java/org/wlld/MatrixTest.java b/src/test/java/org/wlld/MatrixTest.java index 35a0aa9..8f6d113 100644 --- a/src/test/java/org/wlld/MatrixTest.java +++ b/src/test/java/org/wlld/MatrixTest.java @@ -2,7 +2,10 @@ package org.wlld; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.randomForest.DataTable; +import org.wlld.randomForest.Tree; +import java.awt.*; import java.util.*; /** @@ -12,18 +15,31 @@ import java.util.*; */ public class MatrixTest { public static void main(String[] args) throws Exception { - Map map = new TreeMap<>(); - map.put(3.0, "a"); - map.put(2.0, "b"); - map.put(4.0, "c"); - map.put(5.0, "d"); - map.put(1.0, "e"); - for (Map.Entry entry : map.entrySet()) { - System.out.println(entry.getKey()); - } + test4(); + } + public static void test4() throws Exception { + Set column = new HashSet<>(); + column.add("height"); + column.add("weight"); + column.add("sex"); + DataTable dataTable = new DataTable(column); + dataTable.setKey("sex"); + Random random = new Random(); + for (int i = 0; i < 10; i++) { + Food food = new Food(); + food.setHeight(random.nextInt(2)); + food.setWeight(random.nextInt(2)); + food.setSex(random.nextInt(2)); + System.out.println("index==" + i + ",height==" + food.getHeight() + + ",weight==" + food.getWeight() + ",sex==" + food.getSex()); + dataTable.insert(food); + } + Tree tree = new Tree(dataTable); + tree.study(); } + public static void test3() throws Exception { Matrix matrix = new Matrix(4, 3); Matrix matrixY = new Matrix(4, 1);