|
|
|
@ -2,6 +2,7 @@ package org.wlld.randomForest;
|
|
|
|
|
|
|
|
|
|
import org.wlld.tools.ArithUtil;
|
|
|
|
|
|
|
|
|
|
import java.lang.reflect.Method;
|
|
|
|
|
import java.util.*;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
@ -15,6 +16,7 @@ public class Tree {//决策树
|
|
|
|
|
private Node rootNode = new Node();//根节点
|
|
|
|
|
private List<Integer> endList;//最终结果分类
|
|
|
|
|
private List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
|
|
|
|
|
private Random random = new Random();
|
|
|
|
|
|
|
|
|
|
private class Node {
|
|
|
|
|
private boolean isEnd = false;//是否是最底层
|
|
|
|
@ -193,8 +195,37 @@ public class Tree {//决策树
|
|
|
|
|
return attriBute;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void judge() {//进行类别判断
|
|
|
|
|
private int getTypeId(Object ob, String name) throws Exception {
|
|
|
|
|
Class<?> body = ob.getClass();
|
|
|
|
|
String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
|
|
|
|
|
Method method = body.getMethod(methodName);
|
|
|
|
|
return (int) method.invoke(ob);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public int judge(Object ob) throws Exception {//进行类别判断
|
|
|
|
|
return goTree(ob, rootNode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private int goTree(Object ob, Node node) throws Exception {//从树顶向下攀爬
|
|
|
|
|
if (!node.isEnd) {
|
|
|
|
|
int myType = getTypeId(ob, node.key);//当前类别的ID
|
|
|
|
|
List<Node> nodeList = node.nodeList;
|
|
|
|
|
boolean isOk = false;
|
|
|
|
|
for (Node testNode : nodeList) {
|
|
|
|
|
if (testNode.typeId == myType) {
|
|
|
|
|
isOk = true;
|
|
|
|
|
node = testNode;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!isOk) {//当前类别缺失,未知的属性值
|
|
|
|
|
int index = random.nextInt(nodeList.size());
|
|
|
|
|
node = nodeList.get(index);
|
|
|
|
|
}
|
|
|
|
|
return goTree(ob, node);
|
|
|
|
|
} else {
|
|
|
|
|
return node.type;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void study() throws Exception {
|
|
|
|
|