|
|
|
@ -12,7 +12,7 @@ import java.util.*;
|
|
|
|
|
public class Tree {//决策树
|
|
|
|
|
private DataTable dataTable;
|
|
|
|
|
private Map<String, List<Integer>> table;//总样本
|
|
|
|
|
private Node rootNode;//根节点
|
|
|
|
|
private Node rootNode = new Node();//根节点
|
|
|
|
|
private List<Integer> endList;//最终结果分类
|
|
|
|
|
|
|
|
|
|
private class Node {
|
|
|
|
@ -29,7 +29,7 @@ public class Tree {//决策树
|
|
|
|
|
private double gainRatio;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tree(DataTable dataTable) throws Exception {
|
|
|
|
|
public Tree(DataTable dataTable) throws Exception {
|
|
|
|
|
if (dataTable.getKey() != null && dataTable.getLength() > 0) {
|
|
|
|
|
table = dataTable.getTable();
|
|
|
|
|
this.dataTable = dataTable;
|
|
|
|
@ -105,12 +105,11 @@ public class Tree {//决策树
|
|
|
|
|
Set<String> nowAttribute = removeAttribute(attributes, name);
|
|
|
|
|
Node sonNode = new Node();
|
|
|
|
|
nodeList.add(sonNode);
|
|
|
|
|
sonNode.key = mapEntry.getKey();
|
|
|
|
|
sonNode.attribute = nowAttribute;
|
|
|
|
|
List<Integer> list = entry.getValue();
|
|
|
|
|
sonNode.fatherList = list;
|
|
|
|
|
int myNub = list.size();
|
|
|
|
|
double ent = getEnt(list);//每一个信息熵都是一个子集
|
|
|
|
|
double ent = getEnt(list);
|
|
|
|
|
double dNub = ArithUtil.div(myNub, fatherNub);
|
|
|
|
|
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
|
|
|
|
|
gain = getGain(ent, dNub, gain);
|
|
|
|
@ -118,20 +117,28 @@ public class Tree {//决策树
|
|
|
|
|
Gain gain1 = new Gain();
|
|
|
|
|
gainMap.put(name, gain1);
|
|
|
|
|
gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益
|
|
|
|
|
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
|
|
|
|
|
if (IV != 0) {
|
|
|
|
|
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//增益率
|
|
|
|
|
} else {
|
|
|
|
|
gain1.gainRatio = -1;
|
|
|
|
|
}
|
|
|
|
|
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
|
|
|
|
|
i++;
|
|
|
|
|
}
|
|
|
|
|
double avgGain = ArithUtil.div(sigmaG, i);
|
|
|
|
|
double avgGain = sigmaG / i;
|
|
|
|
|
double gainRatio = 0;//最大增益率
|
|
|
|
|
String key = null;//可选属性
|
|
|
|
|
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
|
|
|
|
|
Gain gain = entry.getValue();
|
|
|
|
|
if (gain.gain > avgGain && gain.gainRatio > gainRatio) {
|
|
|
|
|
if (gainRatio == -1) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
if (gain.gain >= avgGain && (gain.gainRatio >= gainRatio || gain.gainRatio == -1)) {
|
|
|
|
|
gainRatio = gain.gainRatio;
|
|
|
|
|
key = entry.getKey();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
node.key = key;
|
|
|
|
|
List<Node> nodeList = nodeMap.get(key);
|
|
|
|
|
for (int j = 0; j < nodeList.size(); j++) {
|
|
|
|
|
Node node1 = nodeList.get(j);
|
|
|
|
@ -181,14 +188,16 @@ public class Tree {//决策树
|
|
|
|
|
public void study() throws Exception {
|
|
|
|
|
if (dataTable != null) {
|
|
|
|
|
endList = dataTable.getTable().get(dataTable.getKey());
|
|
|
|
|
Node node = new Node();
|
|
|
|
|
node.attribute = dataTable.getKeyType();//当前可用属性
|
|
|
|
|
Set<String> set = dataTable.getKeyType();
|
|
|
|
|
set.remove(dataTable.getKey());
|
|
|
|
|
rootNode.attribute = set;//当前可用属性
|
|
|
|
|
List<Integer> list = new ArrayList<>();
|
|
|
|
|
for (int i = 0; i < endList.size(); i++) {
|
|
|
|
|
list.add(i);
|
|
|
|
|
}
|
|
|
|
|
node.fatherList = list;//当前父级样本
|
|
|
|
|
createNode(node);
|
|
|
|
|
rootNode.fatherList = list;//当前父级样本
|
|
|
|
|
List<Node> nodeList = createNode(rootNode);
|
|
|
|
|
rootNode.nodeList = nodeList;
|
|
|
|
|
} else {
|
|
|
|
|
throw new Exception("dataTable is null");
|
|
|
|
|
}
|
|
|
|
|