|
|
@ -14,15 +14,17 @@ public class Tree {//决策树
|
|
|
|
private Map<String, List<Integer>> table;//总样本
|
|
|
|
private Map<String, List<Integer>> table;//总样本
|
|
|
|
private Node rootNode = new Node();//根节点
|
|
|
|
private Node rootNode = new Node();//根节点
|
|
|
|
private List<Integer> endList;//最终结果分类
|
|
|
|
private List<Integer> endList;//最终结果分类
|
|
|
|
|
|
|
|
private List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
|
|
|
|
|
|
|
|
|
|
|
|
private class Node {
|
|
|
|
private class Node {
|
|
|
|
private boolean isEnd = false;
|
|
|
|
private boolean isEnd = false;//是否是最底层
|
|
|
|
private List<Integer> fatherList;//父级样本
|
|
|
|
private List<Integer> fatherList;//父级样本
|
|
|
|
private Set<String> attribute;//当前可用属性
|
|
|
|
private Set<String> attribute;//当前可用属性
|
|
|
|
private String key;//该节点分类属性
|
|
|
|
private String key;//该节点分类属性
|
|
|
|
private int typeId;//该节点该属性分类的Id值
|
|
|
|
private int typeId;//该节点该属性分类的Id值
|
|
|
|
private List<Node> nodeList;//下属节点
|
|
|
|
private List<Node> nodeList;//下属节点
|
|
|
|
private int type;
|
|
|
|
private int type;//最底层的类别
|
|
|
|
|
|
|
|
private Node fatherNode;//父级节点
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private class Gain {
|
|
|
|
private class Gain {
|
|
|
@ -142,6 +144,9 @@ public class Tree {//决策树
|
|
|
|
}
|
|
|
|
}
|
|
|
|
node.key = key;
|
|
|
|
node.key = key;
|
|
|
|
List<Node> nodeList = nodeMap.get(key);
|
|
|
|
List<Node> nodeList = nodeMap.get(key);
|
|
|
|
|
|
|
|
for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系
|
|
|
|
|
|
|
|
nodeList.get(j).fatherNode = node;
|
|
|
|
|
|
|
|
}
|
|
|
|
for (int j = 0; j < nodeList.size(); j++) {
|
|
|
|
for (int j = 0; j < nodeList.size(); j++) {
|
|
|
|
Node node1 = nodeList.get(j);
|
|
|
|
Node node1 = nodeList.get(j);
|
|
|
|
node1.nodeList = createNode(node1);
|
|
|
|
node1.nodeList = createNode(node1);
|
|
|
@ -151,6 +156,7 @@ public class Tree {//决策树
|
|
|
|
//判断类别
|
|
|
|
//判断类别
|
|
|
|
node.isEnd = true;//叶子节点
|
|
|
|
node.isEnd = true;//叶子节点
|
|
|
|
node.type = getType(fatherList);
|
|
|
|
node.type = getType(fatherList);
|
|
|
|
|
|
|
|
lastNodes.add(node);//将全部最后一层节点集合
|
|
|
|
return null;
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -187,6 +193,10 @@ public class Tree {//决策树
|
|
|
|
return attriBute;
|
|
|
|
return attriBute;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public void judge() {//进行类别判断
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public void study() throws Exception {
|
|
|
|
public void study() throws Exception {
|
|
|
|
if (dataTable != null) {
|
|
|
|
if (dataTable != null) {
|
|
|
|
endList = dataTable.getTable().get(dataTable.getKey());
|
|
|
|
endList = dataTable.getTable().get(dataTable.getKey());
|
|
|
@ -200,8 +210,68 @@ public class Tree {//决策树
|
|
|
|
rootNode.fatherList = list;//当前父级样本
|
|
|
|
rootNode.fatherList = list;//当前父级样本
|
|
|
|
List<Node> nodeList = createNode(rootNode);
|
|
|
|
List<Node> nodeList = createNode(rootNode);
|
|
|
|
rootNode.nodeList = nodeList;
|
|
|
|
rootNode.nodeList = nodeList;
|
|
|
|
|
|
|
|
//进行后剪枝
|
|
|
|
|
|
|
|
for (Node lastNode : lastNodes) {
|
|
|
|
|
|
|
|
prune(lastNode.fatherNode);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
lastNodes.clear();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
throw new Exception("dataTable is null");
|
|
|
|
throw new Exception("dataTable is null");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private void prune(Node node) {//执行剪枝
|
|
|
|
|
|
|
|
if (node != null && !node.isEnd) {
|
|
|
|
|
|
|
|
List<Node> listNode = node.nodeList;//子节点
|
|
|
|
|
|
|
|
if (isPrune(node, listNode)) {//剪枝
|
|
|
|
|
|
|
|
deduction(node);
|
|
|
|
|
|
|
|
prune(node.fatherNode);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private void deduction(Node node) {
|
|
|
|
|
|
|
|
node.isEnd = true;
|
|
|
|
|
|
|
|
node.nodeList = null;
|
|
|
|
|
|
|
|
node.type = getType(node.fatherList);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private boolean isPrune(Node father, List<Node> sonNodes) {
|
|
|
|
|
|
|
|
boolean isRemove = false;
|
|
|
|
|
|
|
|
List<Integer> typeList = new ArrayList<>();
|
|
|
|
|
|
|
|
for (int i = 0; i < sonNodes.size(); i++) {
|
|
|
|
|
|
|
|
Node node = sonNodes.get(i);
|
|
|
|
|
|
|
|
List<Integer> list = node.fatherList;
|
|
|
|
|
|
|
|
typeList.add(getType(list));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
int fatherType = getType(father.fatherList);
|
|
|
|
|
|
|
|
int nub = getRightPoint(father.fatherList, fatherType);
|
|
|
|
|
|
|
|
//父级该样本正确率
|
|
|
|
|
|
|
|
double rightFather = ArithUtil.div(nub, father.fatherList.size());
|
|
|
|
|
|
|
|
int rightNub = 0;
|
|
|
|
|
|
|
|
int rightAllNub = 0;
|
|
|
|
|
|
|
|
for (int i = 0; i < sonNodes.size(); i++) {
|
|
|
|
|
|
|
|
Node node = sonNodes.get(i);
|
|
|
|
|
|
|
|
List<Integer> list = node.fatherList;
|
|
|
|
|
|
|
|
int right = getRightPoint(list, typeList.get(i));
|
|
|
|
|
|
|
|
rightNub = rightNub + right;
|
|
|
|
|
|
|
|
rightAllNub = rightAllNub + list.size();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率
|
|
|
|
|
|
|
|
if (rightPoint <= rightFather) {
|
|
|
|
|
|
|
|
isRemove = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return isRemove;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private int getRightPoint(List<Integer> types, int type) {
|
|
|
|
|
|
|
|
int nub = 0;
|
|
|
|
|
|
|
|
for (int index : types) {
|
|
|
|
|
|
|
|
int end = endList.get(index);
|
|
|
|
|
|
|
|
if (end == type) {
|
|
|
|
|
|
|
|
nub++;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return nub;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|