parent
6a0ba2b7f6
commit
4897a38e68
@ -0,0 +1,68 @@
|
||||
package org.wlld.randomForest;
|
||||
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
|
||||
/**
 * In-memory data table that stores integer samples column by column.
 *
 * <p>Each attribute name given to the constructor owns one column; every
 * successful {@link #insert(Object)} appends exactly one value to every column.
 *
 * @author lidapeng
 * @description in-memory data table
 * @date 3:48 PM 2020/2/17
 */
public class DataTable {
    /** Column name mapped to the values of that column, one entry per inserted row. */
    private final Map<String, List<Integer>> table = new HashMap<>();
    /** Attribute (column) names of the table. */
    private final Set<String> keyType;
    /** Name of the column that holds the final classification. */
    private String key;
    /** Number of rows that were inserted successfully. */
    private int length;

    /**
     * Creates an empty table with one column per attribute name.
     *
     * @param key the attribute names; the set is stored as-is (not copied)
     */
    public DataTable(Set<String> key) {
        this.keyType = key;
        for (String name : key) {
            table.put(name, new ArrayList<>());
        }
    }

    /**
     * @return the name of the classification column, or {@code null} if none was set
     */
    public String getKey() {
        return key;
    }

    /**
     * @return the number of rows stored in the table
     */
    public int getLength() {
        return length;
    }

    /**
     * @return the live column map backing this table (not a copy)
     */
    public Map<String, List<Integer>> getTable() {
        return table;
    }

    /**
     * @return the attribute names this table was created with
     */
    public Set<String> getKeyType() {
        return keyType;
    }

    /**
     * Selects the column that holds the final classification.
     *
     * @param key name of the classification column
     * @throws Exception if {@code key} is not one of the table's attributes
     */
    public void setKey(String key) throws Exception {
        if (!keyType.contains(key)) {
            throw new Exception("NOT FIND KEY");
        }
        this.key = key;
    }

    /**
     * Appends one row, reading each attribute through the matching public
     * no-argument getter of {@code ob} (attribute {@code foo} is read via {@code getFoo()}).
     *
     * <p>The insert is all-or-nothing: every attribute is read and validated
     * before anything is stored, so a row that cannot be read completely leaves
     * the columns and the row count untouched. Failures are reported by
     * printing the stack trace; no exception reaches the caller.
     *
     * @param ob the object providing one Integer value per attribute
     */
    public void insert(Object ob) {
        try {
            Class<?> body = ob.getClass();
            // Phase 1: read and validate the whole row without touching the table.
            List<List<Integer>> columns = new ArrayList<>();
            List<Integer> values = new ArrayList<>();
            for (String name : keyType) {
                // Locale.ROOT keeps the getter name independent of the default locale.
                String methodName = "get" + name.substring(0, 1).toUpperCase(Locale.ROOT) + name.substring(1);
                Method method = body.getMethod(methodName);
                Object dm = method.invoke(ob);
                // The table only accepts Integer values.
                if (!(dm instanceof Integer)) {
                    throw new Exception("type not Integer");
                }
                List<Integer> column = table.get(name);
                if (column == null) {
                    // The attribute set was changed after construction.
                    throw new Exception("column not found: " + name);
                }
                columns.add(column);
                values.add((Integer) dm);
            }
            // Phase 2: commit. Nothing below can fail halfway through.
            for (int i = 0; i < columns.size(); i++) {
                columns.get(i).add(values.get(i));
            }
            length++;
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
|
@ -0,0 +1,123 @@
|
||||
package org.wlld.randomForest;
|
||||
|
||||
import org.wlld.tools.ArithUtil;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @description
|
||||
* @date 3:12 下午 2020/2/17
|
||||
*/
|
||||
public class Tree {//决策树
|
||||
private DataTable dataTable;
|
||||
private Map<String, List<Integer>> table;//总样本
|
||||
private Node rootNode;//根节点
|
||||
private List<Integer> endList;//最终结果分类
|
||||
|
||||
private class Node {
|
||||
private Map<String, List<Integer>> fatherTable;//父级样本
|
||||
private Set<String> attribute;//当前可用属性
|
||||
private double Ent;//信息熵
|
||||
private List<Node> nodeList;//下属节点
|
||||
}
|
||||
|
||||
private class Gain {
|
||||
private double gain;
|
||||
private double IV;
|
||||
}
|
||||
|
||||
Tree(DataTable dataTable) throws Exception {
|
||||
if (dataTable.getKey() != null && dataTable.getLength() > 0) {
|
||||
table = dataTable.getTable();
|
||||
this.dataTable = dataTable;
|
||||
} else {
|
||||
throw new Exception("dataTable is empty");
|
||||
}
|
||||
}
|
||||
|
||||
private double log2(double p) {
|
||||
return ArithUtil.div(Math.log(p), Math.log(2));
|
||||
}
|
||||
|
||||
private double getEnt(List<Integer> list) {
|
||||
//记录了每个类别有几个
|
||||
Map<Integer, Integer> myType = new HashMap<>();
|
||||
for (int index : list) {
|
||||
int type = endList.get(index);//最终结果的类别
|
||||
if (myType.containsKey(type)) {
|
||||
myType.put(type, myType.get(type) + 1);
|
||||
} else {
|
||||
myType.put(type, 1);
|
||||
}
|
||||
}
|
||||
double ent = 0;
|
||||
//求信息熵
|
||||
for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
|
||||
double g = ArithUtil.div(entry1.getValue(), list.size());
|
||||
ent = ArithUtil.add(ent, ArithUtil.mul(g, log2(g)));
|
||||
}
|
||||
return -ent;
|
||||
}
|
||||
|
||||
private double getGain(double ent, double dNub, double gain) {
|
||||
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
|
||||
}
|
||||
|
||||
private Gain getGainNode(List<Integer> dataBodyList, double fatherEnt) {
|
||||
Map<Integer, List<Integer>> map = new HashMap<>();
|
||||
int fatherNub = dataBodyList.size();//总样本数
|
||||
double gain = 0;//信息增益
|
||||
double IV = 0;//增益率
|
||||
//该属性每个离散数据分类的集合
|
||||
for (int i = 0; i < dataBodyList.size(); i++) {
|
||||
int classification = dataBodyList.get(i);//当前属性
|
||||
if (map.containsKey(classification)) {
|
||||
List<Integer> list = map.get(classification);
|
||||
list.add(i);
|
||||
} else {
|
||||
List<Integer> list = new ArrayList<>();
|
||||
list.add(i);
|
||||
map.put(classification, list);
|
||||
}
|
||||
}
|
||||
//求信息增益
|
||||
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
|
||||
List<Integer> list = entry.getValue();
|
||||
int myNub = list.size();
|
||||
double ent = getEnt(list);//每一个信息熵都是一个子集
|
||||
double dNub = ArithUtil.div(myNub, fatherNub);
|
||||
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
|
||||
gain = getGain(ent, dNub, gain);
|
||||
}
|
||||
Gain gain1 = new Gain();
|
||||
gain1.gain = ArithUtil.sub(fatherEnt, gain);//信息增益
|
||||
gain1.IV = -IV;
|
||||
return gain1;
|
||||
}
|
||||
|
||||
private Node createNode(Node node) {
|
||||
Map<String, List<Integer>> fatherTable = node.fatherTable;
|
||||
Set<String> attributes = node.attribute;
|
||||
double fatherEnt = node.Ent;
|
||||
for (String name : attributes) {
|
||||
List<Integer> dataBodyList = fatherTable.get(name);
|
||||
Gain gain = getGainNode(dataBodyList, fatherEnt);//信息增益
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public void study() throws Exception {
|
||||
if (dataTable != null) {
|
||||
endList = dataTable.getTable().get(dataTable.getKey());
|
||||
Node node = new Node();
|
||||
node.attribute = dataTable.getKeyType();//当前可用属性
|
||||
node.fatherTable = table;//当前父级样本
|
||||
node.Ent = getEnt(endList);
|
||||
createNode(node);
|
||||
} else {
|
||||
throw new Exception("dataTable is null");
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package org.wlld;
|
||||
|
||||
/**
 * Plain bean holding an integer {@code foodId} and a floating-point {@code testId}.
 *
 * @author lidapeng
 * @description
 * @date 8:11 AM 2020/2/18
 */
public class Food {
    private double testId;
    private int foodId;

    /**
     * @param foodId the new food id
     */
    public void setFoodId(int foodId) {
        this.foodId = foodId;
    }

    /**
     * @param testId the new test id
     */
    public void setTestId(double testId) {
        this.testId = testId;
    }

    /**
     * @return the current food id
     */
    public int getFoodId() {
        return this.foodId;
    }

    /**
     * @return the current test id
     */
    public double getTestId() {
        return this.testId;
    }
}
|
Loading…
Reference in new issue