后剪枝添加完成

pull/1/head
lidapeng 5 years ago
parent f5b2b25d91
commit aa1f5a24a4

@ -14,15 +14,17 @@ public class Tree {//决策树
private Map<String, List<Integer>> table;//总样本 private Map<String, List<Integer>> table;//总样本
private Node rootNode = new Node();//根节点 private Node rootNode = new Node();//根节点
private List<Integer> endList;//最终结果分类 private List<Integer> endList;//最终结果分类
private List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
private class Node { private class Node {
private boolean isEnd = false; private boolean isEnd = false;//是否是最底层
private List<Integer> fatherList;//父级样本 private List<Integer> fatherList;//父级样本
private Set<String> attribute;//当前可用属性 private Set<String> attribute;//当前可用属性
private String key;//该节点分类属性 private String key;//该节点分类属性
private int typeId;//该节点该属性分类的Id值 private int typeId;//该节点该属性分类的Id值
private List<Node> nodeList;//下属节点 private List<Node> nodeList;//下属节点
private int type; private int type;//最底层的类别
private Node fatherNode;//父级节点
} }
private class Gain { private class Gain {
@ -142,6 +144,9 @@ public class Tree {//决策树
} }
node.key = key; node.key = key;
List<Node> nodeList = nodeMap.get(key); List<Node> nodeList = nodeMap.get(key);
for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系
nodeList.get(j).fatherNode = node;
}
for (int j = 0; j < nodeList.size(); j++) { for (int j = 0; j < nodeList.size(); j++) {
Node node1 = nodeList.get(j); Node node1 = nodeList.get(j);
node1.nodeList = createNode(node1); node1.nodeList = createNode(node1);
@ -151,6 +156,7 @@ public class Tree {//决策树
//判断类别 //判断类别
node.isEnd = true;//叶子节点 node.isEnd = true;//叶子节点
node.type = getType(fatherList); node.type = getType(fatherList);
lastNodes.add(node);//将全部最后一层节点集合
return null; return null;
} }
} }
@ -187,6 +193,10 @@ public class Tree {//决策树
return attriBute; return attriBute;
} }
public void judge() {// Perform class judgement; NOTE(review): the body is empty here, so this is not implemented yet
}
public void study() throws Exception { public void study() throws Exception {
if (dataTable != null) { if (dataTable != null) {
endList = dataTable.getTable().get(dataTable.getKey()); endList = dataTable.getTable().get(dataTable.getKey());
@ -200,8 +210,68 @@ public class Tree {//决策树
rootNode.fatherList = list;//当前父级样本 rootNode.fatherList = list;//当前父级样本
List<Node> nodeList = createNode(rootNode); List<Node> nodeList = createNode(rootNode);
rootNode.nodeList = nodeList; rootNode.nodeList = nodeList;
//进行后剪枝
for (Node lastNode : lastNodes) {
prune(lastNode.fatherNode);
}
lastNodes.clear();
} else { } else {
throw new Exception("dataTable is null"); throw new Exception("dataTable is null");
} }
} }
/**
 * Performs pruning starting at {@code node} and moving up through its ancestors.
 *
 * <p>Each visited node that is non-null, not already a leaf, and approved by
 * {@link #isPrune(Node, List)} is collapsed via {@link #deduction(Node)}; the walk then
 * continues with its {@code fatherNode}. The walk stops at the first node that is
 * {@code null}, already marked as a leaf, or not approved for pruning.
 *
 * @param node the node at which to start; may be {@code null}, in which case nothing happens
 */
private void prune(Node node) {
    Node current = node;
    // Iterative form of the upward walk: keep collapsing while the check approves.
    while (current != null && !current.isEnd && isPrune(current, current.nodeList)) {
        deduction(current);
        current = current.fatherNode;
    }
}
/**
 * Collapses {@code node} into a leaf: marks it as an end node, drops its children and
 * assigns it the class that {@code getType} yields for the node's own sample list.
 *
 * @param node the node to turn into a leaf; must not be {@code null}
 */
private void deduction(Node node) {
    final List<Integer> samples = node.fatherList;
    node.isEnd = true;
    node.nodeList = null;// children are discarded
    node.type = getType(samples);
}
/**
 * Decides whether the children of {@code father} should be pruned away.
 *
 * <p>The father's ratio is the number of its samples whose label equals the class
 * {@code getType} picks for the father, divided by the father's sample count. The
 * children's ratio is the sum of the same per-child counts divided by the total number
 * of child samples. Pruning is requested when the children's ratio is less than or equal
 * to the father's ratio.
 *
 * @param father   the node whose subtree is being evaluated
 * @param sonNodes the direct children of {@code father}; must not be {@code null}
 * @return {@code true} if the children's ratio does not exceed the father's ratio
 */
private boolean isPrune(Node father, List<Node> sonNodes) {
    // Class chosen by getType for each child's sample list, in child order.
    final List<Integer> childTypes = new ArrayList<>();
    for (Node child : sonNodes) {
        childTypes.add(getType(child.fatherList));
    }

    // Ratio of father samples matching the father's class.
    final List<Integer> fatherSamples = father.fatherList;
    final int fatherType = getType(fatherSamples);
    final int fatherHits = getRightPoint(fatherSamples, fatherType);
    final double fatherRatio = ArithUtil.div(fatherHits, fatherSamples.size());

    // Pooled ratio over all children, each child judged against its own class.
    int childHits = 0;
    int childTotal = 0;
    int position = 0;
    for (Node child : sonNodes) {
        final List<Integer> samples = child.fatherList;
        childHits += getRightPoint(samples, childTypes.get(position));
        childTotal += samples.size();
        position++;
    }
    final double childRatio = ArithUtil.div(childHits, childTotal);

    return childRatio <= fatherRatio;
}
/**
 * Counts how many of the given sample indices have the label {@code type}.
 *
 * @param types indices into {@code endList}
 * @param type  the label value to match
 * @return the number of indices whose {@code endList} entry equals {@code type}
 */
private int getRightPoint(List<Integer> types, int type) {
    int matched = 0;
    for (Integer sampleIndex : types) {
        final int label = endList.get(sampleIndex);
        matched += (label == type) ? 1 : 0;
    }
    return matched;
}
} }

@ -8,8 +8,26 @@ package org.wlld;
public class Food { public class Food {
private int height;//身高 private int height;//身高
private int weight;//体重 private int weight;//体重
private int h1;
private int h2;
private int sex;//性别 1男 2女 private int sex;//性别 1男 2女
/**
 * Returns the current value of the {@code h1} field.
 *
 * @return the {@code h1} value
 */
public int getH1() {
    return this.h1;
}
/**
 * Stores a new value in the {@code h1} field.
 *
 * @param h1 the value to store
 */
public void setH1(final int h1) {
    this.h1 = h1;
}
/**
 * Returns the current value of the {@code h2} field.
 *
 * @return the {@code h2} value
 */
public int getH2() {
    return this.h2;
}
/**
 * Stores a new value in the {@code h2} field.
 *
 * @param h2 the value to store
 */
public void setH2(final int h2) {
    this.h2 = h2;
}
public int getHeight() { public int getHeight() {
return height; return height;
} }

@ -23,16 +23,21 @@ public class MatrixTest {
column.add("height"); column.add("height");
column.add("weight"); column.add("weight");
column.add("sex"); column.add("sex");
column.add("h1");
column.add("h2");
DataTable dataTable = new DataTable(column); DataTable dataTable = new DataTable(column);
dataTable.setKey("sex"); dataTable.setKey("sex");
Random random = new Random(); Random random = new Random();
for (int i = 0; i < 10; i++) { int cla = 3;
for (int i = 0; i < 50; i++) {
Food food = new Food(); Food food = new Food();
food.setHeight(random.nextInt(2)); food.setHeight(random.nextInt(cla));
food.setWeight(random.nextInt(2)); food.setWeight(random.nextInt(cla));
food.setSex(random.nextInt(2)); food.setSex(random.nextInt(cla));
System.out.println("index==" + i + ",height==" + food.getHeight() + food.setH1(random.nextInt(cla));
",weight==" + food.getWeight() + ",sex==" + food.getSex()); food.setH2(random.nextInt(cla));
// System.out.println("index==" + i + ",height==" + food.getHeight() +
// ",weight==" + food.getWeight() + ",sex==" + food.getSex());
dataTable.insert(food); dataTable.insert(food);
} }
Tree tree = new Tree(dataTable); Tree tree = new Tree(dataTable);

Loading…
Cancel
Save