后剪枝添加完成

pull/1/head
lidapeng 5 years ago
parent f5b2b25d91
commit aa1f5a24a4

@ -14,15 +14,17 @@ public class Tree {//决策树
private Map<String, List<Integer>> table;//总样本
private Node rootNode = new Node();//根节点
private List<Integer> endList;//最终结果分类
private List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
private class Node {
private boolean isEnd = false;
private boolean isEnd = false;//是否是最底层
private List<Integer> fatherList;//父级样本
private Set<String> attribute;//当前可用属性
private String key;//该节点分类属性
private int typeId;//该节点该属性分类的Id值
private List<Node> nodeList;//下属节点
private int type;
private int type;//最底层的类别
private Node fatherNode;//父级节点
}
private class Gain {
@ -142,6 +144,9 @@ public class Tree {//决策树
}
node.key = key;
List<Node> nodeList = nodeMap.get(key);
for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系
nodeList.get(j).fatherNode = node;
}
for (int j = 0; j < nodeList.size(); j++) {
Node node1 = nodeList.get(j);
node1.nodeList = createNode(node1);
@ -151,6 +156,7 @@ public class Tree {//决策树
//判断类别
node.isEnd = true;//叶子节点
node.type = getType(fatherList);
lastNodes.add(node);//将全部最后一层节点集合
return null;
}
}
@ -187,6 +193,10 @@ public class Tree {//决策树
return attriBute;
}
public void judge() {//进行类别判断
}
public void study() throws Exception {
if (dataTable != null) {
endList = dataTable.getTable().get(dataTable.getKey());
@ -200,8 +210,68 @@ public class Tree {//决策树
rootNode.fatherList = list;//当前父级样本
List<Node> nodeList = createNode(rootNode);
rootNode.nodeList = nodeList;
//进行后剪枝
for (Node lastNode : lastNodes) {
prune(lastNode.fatherNode);
}
lastNodes.clear();
} else {
throw new Exception("dataTable is null");
}
}
private void prune(Node node) {//执行剪枝
if (node != null && !node.isEnd) {
List<Node> listNode = node.nodeList;//子节点
if (isPrune(node, listNode)) {//剪枝
deduction(node);
prune(node.fatherNode);
}
}
}
private void deduction(Node node) {
node.isEnd = true;
node.nodeList = null;
node.type = getType(node.fatherList);
}
private boolean isPrune(Node father, List<Node> sonNodes) {
boolean isRemove = false;
List<Integer> typeList = new ArrayList<>();
for (int i = 0; i < sonNodes.size(); i++) {
Node node = sonNodes.get(i);
List<Integer> list = node.fatherList;
typeList.add(getType(list));
}
int fatherType = getType(father.fatherList);
int nub = getRightPoint(father.fatherList, fatherType);
//父级该样本正确率
double rightFather = ArithUtil.div(nub, father.fatherList.size());
int rightNub = 0;
int rightAllNub = 0;
for (int i = 0; i < sonNodes.size(); i++) {
Node node = sonNodes.get(i);
List<Integer> list = node.fatherList;
int right = getRightPoint(list, typeList.get(i));
rightNub = rightNub + right;
rightAllNub = rightAllNub + list.size();
}
double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率
if (rightPoint <= rightFather) {
isRemove = true;
}
return isRemove;
}
private int getRightPoint(List<Integer> types, int type) {
int nub = 0;
for (int index : types) {
int end = endList.get(index);
if (end == type) {
nub++;
}
}
return nub;
}
}

@ -8,8 +8,26 @@ package org.wlld;
public class Food {
private int height;//身高
private int weight;//体重
private int h1;
private int h2;
private int sex;//性别 1男 2女
public int getH1() {
return h1;
}
public void setH1(int h1) {
this.h1 = h1;
}
public int getH2() {
return h2;
}
public void setH2(int h2) {
this.h2 = h2;
}
public int getHeight() {
return height;
}

@ -23,16 +23,21 @@ public class MatrixTest {
column.add("height");
column.add("weight");
column.add("sex");
column.add("h1");
column.add("h2");
DataTable dataTable = new DataTable(column);
dataTable.setKey("sex");
Random random = new Random();
for (int i = 0; i < 10; i++) {
int cla = 3;
for (int i = 0; i < 50; i++) {
Food food = new Food();
food.setHeight(random.nextInt(2));
food.setWeight(random.nextInt(2));
food.setSex(random.nextInt(2));
System.out.println("index==" + i + ",height==" + food.getHeight() +
",weight==" + food.getWeight() + ",sex==" + food.getSex());
food.setHeight(random.nextInt(cla));
food.setWeight(random.nextInt(cla));
food.setSex(random.nextInt(cla));
food.setH1(random.nextInt(cla));
food.setH2(random.nextInt(cla));
// System.out.println("index==" + i + ",height==" + food.getHeight() +
// ",weight==" + food.getWeight() + ",sex==" + food.getSex());
dataTable.insert(food);
}
Tree tree = new Tree(dataTable);

Loading…
Cancel
Save