决策树生成完毕,还没有做剪枝

pull/1/head
lidapeng 5 years ago
parent 63b53a5089
commit 57849d8543

@ -12,7 +12,7 @@ import java.util.*;
public class Tree {//决策树
private DataTable dataTable;
private Map<String, List<Integer>> table;//总样本
private Node rootNode;//根节点
private Node rootNode = new Node();//根节点
private List<Integer> endList;//最终结果分类
private class Node {
@ -29,7 +29,7 @@ public class Tree {//决策树
private double gainRatio;
}
Tree(DataTable dataTable) throws Exception {
public Tree(DataTable dataTable) throws Exception {
if (dataTable.getKey() != null && dataTable.getLength() > 0) {
table = dataTable.getTable();
this.dataTable = dataTable;
@ -105,12 +105,11 @@ public class Tree {//决策树
Set<String> nowAttribute = removeAttribute(attributes, name);
Node sonNode = new Node();
nodeList.add(sonNode);
sonNode.key = mapEntry.getKey();
sonNode.attribute = nowAttribute;
List<Integer> list = entry.getValue();
sonNode.fatherList = list;
int myNub = list.size();
double ent = getEnt(list);//每一个信息熵都是一个子集
double ent = getEnt(list);
double dNub = ArithUtil.div(myNub, fatherNub);
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
gain = getGain(ent, dNub, gain);
@ -118,20 +117,28 @@ public class Tree {//决策树
Gain gain1 = new Gain();
gainMap.put(name, gain1);
gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
if (IV != 0) {
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
} else {
gain1.gainRatio = -1;
}
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
i++;
}
double avgGain = ArithUtil.div(sigmaG, i);
double avgGain = sigmaG / i;
double gainRatio = 0;//最大增益率
String key = null;//可选属性
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
Gain gain = entry.getValue();
if (gain.gain > avgGain && gain.gainRatio > gainRatio) {
if (gainRatio == -1) {
break;
}
if (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gain.gainRatio == -1)) {
gainRatio = gain.gainRatio;
key = entry.getKey();
}
}
node.key = key;
List<Node> nodeList = nodeMap.get(key);
for (int j = 0; j < nodeList.size(); j++) {
Node node1 = nodeList.get(j);
@ -181,14 +188,16 @@ public class Tree {//决策树
public void study() throws Exception {
if (dataTable != null) {
endList = dataTable.getTable().get(dataTable.getKey());
Node node = new Node();
node.attribute = dataTable.getKeyType();//当前可用属性
Set<String> set = dataTable.getKeyType();
set.remove(dataTable.getKey());
rootNode.attribute = set;//当前可用属性
List<Integer> list = new ArrayList<>();
for (int i = 0; i < endList.size(); i++) {
list.add(i);
}
node.fatherList = list;//当前父级样本
createNode(node);
rootNode.fatherList = list;//当前父级样本
List<Node> nodeList = createNode(rootNode);
rootNode.nodeList = nodeList;
} else {
throw new Exception("dataTable is null");
}

@ -6,22 +6,31 @@ package org.wlld;
* @date 8:11 2020/2/18
*/
public class Food {
private int foodId;
private double testId;
private int height;//身高
private int weight;//体重
private int sex;//性别 1男 2女
public int getFoodId() {
return foodId;
public int getHeight() {
return height;
}
public void setFoodId(int foodId) {
this.foodId = foodId;
public void setHeight(int height) {
this.height = height;
}
public double getTestId() {
return testId;
public int getWeight() {
return weight;
}
public void setTestId(double testId) {
this.testId = testId;
public void setWeight(int weight) {
this.weight = weight;
}
public int getSex() {
return sex;
}
public void setSex(int sex) {
this.sex = sex;
}
}

@ -2,7 +2,10 @@ package org.wlld;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.randomForest.DataTable;
import org.wlld.randomForest.Tree;
import java.awt.*;
import java.util.*;
/**
@ -12,18 +15,31 @@ import java.util.*;
*/
public class MatrixTest {
public static void main(String[] args) throws Exception {
Map<Double, String> map = new TreeMap<>();
map.put(3.0, "a");
map.put(2.0, "b");
map.put(4.0, "c");
map.put(5.0, "d");
map.put(1.0, "e");
for (Map.Entry<Double, String> entry : map.entrySet()) {
System.out.println(entry.getKey());
}
test4();
}
/**
 * Smoke test for the decision tree: fills a {@code DataTable} with ten randomly
 * generated {@code Food} rows, prints each row, then builds a {@code Tree} from
 * the table and calls {@code study()} on it.
 *
 * <p>The {@code Random} is unseeded, so the generated rows (and therefore the
 * resulting tree) differ from run to run.
 *
 * @throws Exception propagated from the {@code DataTable} / {@code Tree} calls
 */
public static void test4() throws Exception {
// Column names registered with the table; they match the Food property names.
Set<String> column = new HashSet<>();
column.add("height");
column.add("weight");
column.add("sex");
DataTable dataTable = new DataTable(column);
// "sex" is the key column; Tree.study() reads this column as the final classification.
dataTable.setKey("sex");
Random random = new Random();
for (int i = 0; i < 10; i++) {
Food food = new Food();
// nextInt(2) yields 0 or 1, so every attribute is binary in this sample.
food.setHeight(random.nextInt(2));
food.setWeight(random.nextInt(2));
// NOTE(review): Food documents sex as 1 = male, 2 = female, but the values here are 0/1 — confirm intended encoding.
food.setSex(random.nextInt(2));
// Echo the generated row so the sample set can be inspected by hand.
System.out.println("index==" + i + ",height==" + food.getHeight() +
",weight==" + food.getWeight() + ",sex==" + food.getSex());
dataTable.insert(food);
}
// Build the tree over the populated table and run training; nothing is asserted on the result.
Tree tree = new Tree(dataTable);
tree.study();
}
public static void test3() throws Exception {
Matrix matrix = new Matrix(4, 3);
Matrix matrixY = new Matrix(4, 1);

Loading…
Cancel
Save