From aa1f5a24a4d68a650f518e21c531cd6d4413f540 Mon Sep 17 00:00:00 2001 From: lidapeng Date: Sat, 22 Feb 2020 14:09:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8E=E5=89=AA=E6=9E=9D=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/main/java/org/wlld/randomForest/Tree.java | 74 ++++++++++++++++++- src/test/java/org/wlld/Food.java | 18 +++++ src/test/java/org/wlld/MatrixTest.java | 17 +++-- 3 files changed, 101 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/wlld/randomForest/Tree.java b/src/main/java/org/wlld/randomForest/Tree.java index 412fa69..18e3955 100644 --- a/src/main/java/org/wlld/randomForest/Tree.java +++ b/src/main/java/org/wlld/randomForest/Tree.java @@ -14,15 +14,17 @@ public class Tree {//决策树 private Map> table;//总样本 private Node rootNode = new Node();//根节点 private List endList;//最终结果分类 + private List lastNodes = new ArrayList<>();//最后一层节点集合 private class Node { - private boolean isEnd = false; + private boolean isEnd = false;//是否是最底层 private List fatherList;//父级样本 private Set attribute;//当前可用属性 private String key;//该节点分类属性 private int typeId;//该节点该属性分类的Id值 private List nodeList;//下属节点 - private int type; + private int type;//最底层的类别 + private Node fatherNode;//父级节点 } private class Gain { @@ -142,6 +144,9 @@ public class Tree {//决策树 } node.key = key; List nodeList = nodeMap.get(key); + for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系 + nodeList.get(j).fatherNode = node; + } for (int j = 0; j < nodeList.size(); j++) { Node node1 = nodeList.get(j); node1.nodeList = createNode(node1); @@ -151,6 +156,7 @@ public class Tree {//决策树 //判断类别 node.isEnd = true;//叶子节点 node.type = getType(fatherList); + lastNodes.add(node);//将全部最后一层节点集合 return null; } } @@ -187,6 +193,10 @@ public class Tree {//决策树 return attriBute; } + public void judge() {//进行类别判断 + + } + public void study() throws Exception { if (dataTable != null) { endList = dataTable.getTable().get(dataTable.getKey()); @@ -200,8 +210,68 @@ public class Tree {//决策树 rootNode.fatherList = list;//当前父级样本 List nodeList = createNode(rootNode); rootNode.nodeList = nodeList; + //进行后剪枝 + for (Node lastNode : lastNodes) { + prune(lastNode.fatherNode); + } + lastNodes.clear(); } else { throw new Exception("dataTable is null"); } } + + private void prune(Node node) {//执行剪枝 + if (node != null && !node.isEnd) { + List listNode = node.nodeList;//子节点 + if (isPrune(node, listNode)) {//剪枝 + deduction(node); + prune(node.fatherNode); + } + } + } + + private void deduction(Node node) { + node.isEnd = true; + node.nodeList = null; + node.type = getType(node.fatherList); + } + + private boolean isPrune(Node father, List sonNodes) { + boolean isRemove = false; + List typeList = new ArrayList<>(); + for (int i = 0; i < sonNodes.size(); i++) { + Node node = sonNodes.get(i); + List list = node.fatherList; + typeList.add(getType(list)); + } + int fatherType = getType(father.fatherList); + int nub = getRightPoint(father.fatherList, fatherType); + //父级该样本正确率 + double rightFather = ArithUtil.div(nub, father.fatherList.size()); + int rightNub = 0; + int rightAllNub = 0; + for (int i = 0; i < sonNodes.size(); i++) { + Node node = sonNodes.get(i); + List list = node.fatherList; + int right = getRightPoint(list, typeList.get(i)); + rightNub = rightNub + right; + rightAllNub = rightAllNub + list.size(); + } + double rightPoint = ArithUtil.div(rightNub, rightAllNub);//子节点正确率 + if (rightPoint <= rightFather) { + isRemove = true; + } + return isRemove; + } + + private int getRightPoint(List types, int type) { + int nub = 0; + for (int index : types) { + int end = endList.get(index); + if (end == type) { + nub++; + } + } + return nub; + } } diff --git a/src/test/java/org/wlld/Food.java b/src/test/java/org/wlld/Food.java index bdd0e4f..e794f8c 100644 --- a/src/test/java/org/wlld/Food.java +++ b/src/test/java/org/wlld/Food.java @@ -8,8 +8,26 @@ package org.wlld; public class Food { private int height;//身高 private int weight;//体重 + private int h1; + private int h2; private int sex;//性别 1男 2女 + public int getH1() { + return h1; + } + + public void setH1(int h1) { + this.h1 = h1; + } + + public int getH2() { + return h2; + } + + public void setH2(int h2) { + this.h2 = h2; + } + public int getHeight() { return height; } diff --git a/src/test/java/org/wlld/MatrixTest.java b/src/test/java/org/wlld/MatrixTest.java index 8f6d113..a8c8105 100644 --- a/src/test/java/org/wlld/MatrixTest.java +++ b/src/test/java/org/wlld/MatrixTest.java @@ -23,16 +23,21 @@ public class MatrixTest { column.add("height"); column.add("weight"); column.add("sex"); + column.add("h1"); + column.add("h2"); DataTable dataTable = new DataTable(column); dataTable.setKey("sex"); Random random = new Random(); - for (int i = 0; i < 10; i++) { + int cla = 3; + for (int i = 0; i < 50; i++) { Food food = new Food(); - food.setHeight(random.nextInt(2)); - food.setWeight(random.nextInt(2)); - food.setSex(random.nextInt(2)); - System.out.println("index==" + i + ",height==" + food.getHeight() + - ",weight==" + food.getWeight() + ",sex==" + food.getSex()); + food.setHeight(random.nextInt(cla)); + food.setWeight(random.nextInt(cla)); + food.setSex(random.nextInt(cla)); + food.setH1(random.nextInt(cla)); + food.setH2(random.nextInt(cla)); +// System.out.println("index==" + i + ",height==" + food.getHeight() + +// ",weight==" + food.getWeight() + ",sex==" + food.getSex()); dataTable.insert(food); } Tree tree = new Tree(dataTable);