|
|
@ -16,15 +16,17 @@ public class Tree {//决策树
|
|
|
|
private List<Integer> endList;// final result classification
|
|
|
|
private List<Integer> endList;// final result classification
|
|
|
|
|
|
|
|
|
|
|
|
// Node: one node of the decision tree.
// NOTE(review): this span looks like a flattened side-by-side diff ("|" lines are filler), so the
// class header and some fields appear twice. fatherTable and Ent are only used by the
// createNode(Node) variant that returns Node; isEnd, fatherList, key and type are used by the
// variant that returns List<Node> — presumably the former are the removed fields; confirm
// against the real file.
private class Node {
|
|
|
|
private class Node {
|
|
|
|
private Map<String, List<Integer>> fatherTable;// parent samples
|
|
|
|
// set to true by createNode when the node has no attributes left to split on
private boolean isEnd = false;
|
|
|
|
|
|
|
|
private List<Integer> fatherList;// parent samples
|
|
|
|
private Set<String> attribute;// currently available attributes
|
|
|
|
private Set<String> attribute;// currently available attributes
|
|
|
|
private double Ent;// information entropy
|
|
|
|
private String key;// the attribute this node is classified by
|
|
|
|
private List<Node> nodeList;// child nodes
|
|
|
|
private List<Node> nodeList;// child nodes
|
|
|
|
|
|
|
|
// assigned from getType(fatherList) by createNode when isEnd is set
private int type;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Gain: per-attribute result holder filled in while evaluating a split.
// NOTE(review): this span looks like a flattened side-by-side diff ("|" lines are filler), so the
// class header and the gain field appear twice. IV is only written by getGainNode(...) and
// gainRatio only by the createNode(Node) variant that returns List<Node> — presumably IV is the
// removed field and gainRatio the added one; confirm against the real file.
private class Gain {
|
|
|
|
private class Gain {
|
|
|
|
// information gain: parent entropy minus the accumulated value returned by getGain(...)
private double gain;
|
|
|
|
private double gain;
|
|
|
|
// negated sum of dNub * log2(dNub) over the value groups of an attribute
private double IV;
|
|
|
|
// gain divided by the negated sum of dNub * log2(dNub)
private double gainRatio;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Tree(DataTable dataTable) throws Exception {
|
|
|
|
Tree(DataTable dataTable) throws Exception {
|
|
|
@ -64,48 +66,116 @@ public class Tree {//决策树
|
|
|
|
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
|
|
|
|
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NOTE(review): this span looks like a flattened side-by-side diff ("|" lines are filler). Up to
// the line "list.add(i);" the opening lines of getGainNode(List<Integer>, double) alternate with
// the opening lines of createNode(Node); after that only createNode lines remain. The comments
// below describe createNode.
//
// createNode: builds the child nodes of `node`.
//  - When node.attribute is non-empty: groups the sample indices in node.fatherList by the value
//    of every remaining attribute (read through table.get(attr).get(index)), creates one candidate
//    child Node per (attribute, value) group, computes a Gain (gain and gainRatio) per attribute,
//    selects the attribute whose gain is above the average gain and whose gainRatio is the
//    largest, recurses into that attribute's children and returns them.
//  - Otherwise: marks `node` as a leaf (isEnd = true), stores getType(fatherList) in node.type
//    and returns null.
private Gain getGainNode(List<Integer> dataBodyList, double fatherEnt) {
|
|
|
|
private List<Node> createNode(Node node) {
|
|
|
|
Map<Integer, List<Integer>> map = new HashMap<>();
|
|
|
|
Set<String> attributes = node.attribute;
|
|
|
|
int fatherNub = dataBodyList.size();// total number of samples
|
|
|
|
List<Integer> fatherList = node.fatherList;
|
|
|
|
double gain = 0;// information gain
|
|
|
|
if (attributes.size() > 0) {
|
|
|
|
double IV = 0;// gain ratio
|
|
|
|
Map<String, Map<Integer, List<Integer>>> mapAll = new HashMap<>();
|
|
|
|
// the set of samples for each discrete value of this attribute
|
|
|
|
double fatherEnt = getEnt(fatherList);
|
|
|
|
for (int i = 0; i < dataBodyList.size(); i++) {
|
|
|
|
int fatherNub = fatherList.size();// total number of samples
|
|
|
|
int classification = dataBodyList.get(i);// current attribute
|
|
|
|
// the set of samples for each discrete value of this attribute
|
|
|
|
if (map.containsKey(classification)) {
|
|
|
|
for (int i = 0; i < fatherList.size(); i++) {
|
|
|
|
List<Integer> list = map.get(classification);
|
|
|
|
int index = fatherList.get(i);// sample number
|
|
|
|
list.add(i);
|
|
|
|
// mapAll: attribute name -> (attribute value -> sample indices having that value)
for (String attr : attributes) {
|
|
|
|
|
|
|
|
if (!mapAll.containsKey(attr)) {
|
|
|
|
|
|
|
|
mapAll.put(attr, new HashMap<>());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
Map<Integer, List<Integer>> map = mapAll.get(attr);
|
|
|
|
|
|
|
|
int attrValue = table.get(attr).get(index);
|
|
|
|
|
|
|
|
if (!map.containsKey(attrValue)) {
|
|
|
|
|
|
|
|
map.put(attrValue, new ArrayList<>());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
List<Integer> list = map.get(attrValue);
|
|
|
|
|
|
|
|
list.add(index);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// nodeMap: attribute name -> candidate child nodes, one per value of that attribute
Map<String, List<Node>> nodeMap = new HashMap<>();
|
|
|
|
|
|
|
|
// i counts the evaluated attributes; sigmaG sums their gains (both feed avgGain below)
int i = 0;
|
|
|
|
|
|
|
|
double sigmaG = 0;
|
|
|
|
|
|
|
|
Map<String, Gain> gainMap = new HashMap<>();
|
|
|
|
|
|
|
|
for (Map.Entry<String, Map<Integer, List<Integer>>> mapEntry : mapAll.entrySet()) {
|
|
|
|
|
|
|
|
Map<Integer, List<Integer>> map = mapEntry.getValue();
|
|
|
|
|
|
|
|
// compute the information gain
|
|
|
|
|
|
|
|
double gain = 0;// information gain
|
|
|
|
|
|
|
|
double IV = 0;// gain ratio
|
|
|
|
|
|
|
|
List<Node> nodeList = new ArrayList<>();
|
|
|
|
|
|
|
|
String name = mapEntry.getKey();
|
|
|
|
|
|
|
|
nodeMap.put(name, nodeList);
|
|
|
|
|
|
|
|
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
|
|
|
|
|
|
|
|
// the child may split on every attribute except the one used here
Set<String> nowAttribute = removeAttribute(attributes, name);
|
|
|
|
|
|
|
|
Node sonNode = new Node();
|
|
|
|
|
|
|
|
nodeList.add(sonNode);
|
|
|
|
|
|
|
|
sonNode.key = mapEntry.getKey();
|
|
|
|
|
|
|
|
sonNode.attribute = nowAttribute;
|
|
|
|
|
|
|
|
List<Integer> list = entry.getValue();
|
|
|
|
|
|
|
|
sonNode.fatherList = list;
|
|
|
|
|
|
|
|
int myNub = list.size();
|
|
|
|
|
|
|
|
double ent = getEnt(list);// each entropy is computed over one subset
|
|
|
|
|
|
|
|
// dNub: share of the parent's samples that fall into this value group
double dNub = ArithUtil.div(myNub, fatherNub);
|
|
|
|
|
|
|
|
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
|
|
|
|
|
|
|
|
gain = getGain(ent, dNub, gain);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
Gain gain1 = new Gain();
|
|
|
|
|
|
|
|
gainMap.put(name, gain1);
|
|
|
|
|
|
|
|
gain1.gain = ArithUtil.sub(fatherEnt, gain);// information gain
|
|
|
|
|
|
|
|
// NOTE(review): when an attribute has a single value group, dNub is presumably 1 and IV stays 0,
// making the divisor zero — confirm how ArithUtil.div handles that.
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);// gain ratio
|
|
|
|
|
|
|
|
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
|
|
|
|
|
|
|
|
i++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
double avgGain = ArithUtil.div(sigmaG, i);
|
|
|
|
|
|
|
|
double gainRatio = 0;// maximum gain ratio
|
|
|
|
|
|
|
|
String key = null;// selected attribute
|
|
|
|
|
|
|
|
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
|
|
|
|
|
|
|
|
Gain gain = entry.getValue();
|
|
|
|
|
|
|
|
// both comparisons are strict: an attribute whose gain only equals the average is never picked
if (gain.gain > avgGain && gain.gainRatio > gainRatio) {
|
|
|
|
|
|
|
|
gainRatio = gain.gainRatio;
|
|
|
|
|
|
|
|
key = entry.getKey();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// NOTE(review): if no attribute passed the test above, key is still null here; unless nodeMap
// holds a null key, nodeMap.get(key) then returns null and nodeList.size() below throws a
// NullPointerException. This looks reachable when only one attribute is left (its gain
// presumably equals avgGain) — confirm and add a leaf fallback.
List<Node> nodeList = nodeMap.get(key);
|
|
|
|
|
|
|
|
for (int j = 0; j < nodeList.size(); j++) {
|
|
|
|
|
|
|
|
Node node1 = nodeList.get(j);
|
|
|
|
|
|
|
|
node1.nodeList = createNode(node1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nodeList;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// determine the class
|
|
|
|
|
|
|
|
node.isEnd = true;
|
|
|
|
|
|
|
|
node.type = getType(fatherList);
|
|
|
|
|
|
|
|
return null;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// NOTE(review): this span looks like a flattened side-by-side diff ("|" lines are filler). From
// the first "} else {" onward, lines of getType(List<Integer>) alternate with the trailing lines
// of getGainNode(...) (those using map, classification, IV, gain, gain1). The comments below
// describe getType.
//
// getType: for the sample indices in `list`, counts how often each value of endList.get(index)
// occurs and returns the value with the highest count. Returns 0 when `list` is empty, because
// the result starts at 0 and is only replaced by an entry whose count is greater than 0. On a
// tie the first entry met while iterating myType wins, since the comparison is a strict ">".
private int getType(List<Integer> list) {
|
|
|
|
|
|
|
|
// myType: class value -> number of samples in `list` with that value
Map<Integer, Integer> myType = new HashMap<>();
|
|
|
|
|
|
|
|
for (int index : list) {
|
|
|
|
|
|
|
|
int type = endList.get(index);// class of the final result
|
|
|
|
|
|
|
|
if (myType.containsKey(type)) {
|
|
|
|
|
|
|
|
myType.put(type, myType.get(type) + 1);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
List<Integer> list = new ArrayList<>();
|
|
|
|
myType.put(type, 1);
|
|
|
|
list.add(i);
|
|
|
|
|
|
|
|
map.put(classification, list);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// compute the information gain
|
|
|
|
// getType: most frequent class value found so far
int type = 0;
|
|
|
|
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
|
|
|
|
// getType: highest count seen so far
int nub = 0;
|
|
|
|
List<Integer> list = entry.getValue();
|
|
|
|
for (Map.Entry<Integer, Integer> entry : myType.entrySet()) {
|
|
|
|
int myNub = list.size();
|
|
|
|
int nowNub = entry.getValue();
|
|
|
|
double ent = getEnt(list);// each entropy is computed over one subset
|
|
|
|
if (nowNub > nub) {
|
|
|
|
double dNub = ArithUtil.div(myNub, fatherNub);
|
|
|
|
type = entry.getKey();
|
|
|
|
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
|
|
|
|
nub = nowNub;
|
|
|
|
gain = getGain(ent, dNub, gain);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
Gain gain1 = new Gain();
|
|
|
|
return type;
|
|
|
|
gain1.gain = ArithUtil.sub(fatherEnt, gain);// information gain
|
|
|
|
|
|
|
|
gain1.IV = -IV;
|
|
|
|
|
|
|
|
return gain1;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NOTE(review): this span looks like a flattened side-by-side diff ("|" lines are filler). Lines
// of a createNode(Node) variant that returns Node (it reads node.fatherTable and node.Ent, calls
// getGainNode for every attribute and always returns null) alternate with lines of
// removeAttribute(Set<String>, String). The comments below describe removeAttribute.
//
// removeAttribute: returns a new HashSet holding every element of `attributes` that is not equal
// to `name`. `attributes` itself is only iterated, never modified.
private Node createNode(Node node) {
|
|
|
|
private Set<String> removeAttribute(Set<String> attributes, String name) {
|
|
|
|
Map<String, List<Integer>> fatherTable = node.fatherTable;
|
|
|
|
Set<String> attriBute = new HashSet<>();
|
|
|
|
Set<String> attributes = node.attribute;
|
|
|
|
for (String myName : attributes) {
|
|
|
|
double fatherEnt = node.Ent;
|
|
|
|
if (!myName.equals(name)) {
|
|
|
|
for (String name : attributes) {
|
|
|
|
attriBute.add(myName);
|
|
|
|
List<Integer> dataBodyList = fatherTable.get(name);
|
|
|
|
}
|
|
|
|
Gain gain = getGainNode(dataBodyList, fatherEnt);// information gain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return null;
|
|
|
|
return attriBute;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public void study() throws Exception {
|
|
|
|
public void study() throws Exception {
|
|
|
@ -113,8 +183,11 @@ public class Tree {//决策树
|
|
|
|
endList = dataTable.getTable().get(dataTable.getKey());
|
|
|
|
endList = dataTable.getTable().get(dataTable.getKey());
|
|
|
|
Node node = new Node();
|
|
|
|
Node node = new Node();
|
|
|
|
node.attribute = dataTable.getKeyType();//当前可用属性
|
|
|
|
node.attribute = dataTable.getKeyType();//当前可用属性
|
|
|
|
node.fatherTable = table;//当前父级样本
|
|
|
|
List<Integer> list = new ArrayList<>();
|
|
|
|
node.Ent = getEnt(endList);
|
|
|
|
for (int i = 0; i < endList.size(); i++) {
|
|
|
|
|
|
|
|
list.add(i);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
node.fatherList = list;//当前父级样本
|
|
|
|
createNode(node);
|
|
|
|
createNode(node);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
throw new Exception("dataTable is null");
|
|
|
|
throw new Exception("dataTable is null");
|
|
|
|