增加决策树工具包

pull/1/head
lidapeng 5 years ago
parent 6a0ba2b7f6
commit 4897a38e68

@ -4,8 +4,8 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>org.wlld</groupId> <groupId>com.github</groupId>
<artifactId>myBrain</artifactId> <artifactId>ImageMarket</artifactId>
<version>1.0.0</version> <version>1.0.0</version>
<name>myBrain</name> <name>myBrain</name>

@ -0,0 +1,68 @@
package org.wlld.randomForest;
import java.lang.reflect.Method;
import java.util.*;
/**
 * Column-oriented table of integer samples.
 * <p>
 * Each attribute name maps to a list of Integer values; the i-th entry of every
 * list belongs to the i-th inserted row. One of the attributes can be marked as
 * the final classification field via {@link #setKey(String)}.
 *
 * @author lidapeng
 * @date 3:48 2020/2/17
 */
public class DataTable {
    // attribute name -> column of values, one entry per inserted row
    private Map<String, List<Integer>> table = new HashMap<>();
    // attribute names of the table
    private Set<String> keyType;
    // name of the final classification field
    private String key;
    // number of rows successfully inserted
    private int length;

    /**
     * Creates an empty table with one empty column per attribute name.
     *
     * @param key attribute names of the table
     */
    public DataTable(Set<String> key) {
        this.keyType = key;
        for (String name : key) {
            table.put(name, new ArrayList<>());
        }
    }

    /** @return name of the classification field, or null if none was set */
    public String getKey() {
        return key;
    }

    /** @return number of rows stored in the table */
    public int getLength() {
        return length;
    }

    /** @return the backing column map (not a copy) */
    public Map<String, List<Integer>> getTable() {
        return table;
    }

    /** @return the attribute names this table was created with (not a copy) */
    public Set<String> getKeyType() {
        return keyType;
    }

    /**
     * Marks one of the table's attributes as the classification field.
     *
     * @param key attribute name; must be one of the names given to the constructor
     * @throws Exception if the name is not an attribute of this table
     */
    public void setKey(String key) throws Exception {
        if (keyType.contains(key)) {
            this.key = key;
        } else {
            throw new Exception("NOT FIND KEY");
        }
    }

    /**
     * Appends one row, reading every attribute from the object through its
     * public no-arg getter ({@code getXxx} for attribute {@code xxx}).
     * <p>
     * The row is validated completely before anything is stored: if any getter
     * is missing, fails, or yields a value that is not an Integer, the stack
     * trace is printed and the table is left exactly as it was. This keeps all
     * columns the same size and {@link #getLength()} equal to that size.
     *
     * @param ob object supplying the row values
     */
    public void insert(Object ob) {
        try {
            Class<?> body = ob.getClass();
            // Stage the whole row first so a failure half-way cannot leave
            // some columns longer than others.
            Map<String, Integer> row = new HashMap<>();
            for (String name : keyType) {
                String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
                Method method = body.getMethod(methodName);
                Object dm = method.invoke(ob);
                if (dm instanceof Integer) {// the table only accepts Integer values
                    row.put(name, (Integer) dm);
                } else {
                    throw new Exception("type not Integer");
                }
            }
            // Every attribute validated: commit the row and count it.
            for (Map.Entry<String, Integer> cell : row.entrySet()) {
                table.get(cell.getKey()).add(cell.getValue());
            }
            length++;
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

@ -0,0 +1,123 @@
package org.wlld.randomForest;
import org.wlld.tools.ArithUtil;
import java.util.*;
/**
 * Decision tree over a {@link DataTable} of integer-coded (discrete) attributes.
 * <p>
 * Work in progress: {@link #study()} computes the root entropy and evaluates the
 * information gain of every candidate attribute, but {@code createNode} does not
 * yet select a split or build child nodes.
 *
 * @author lidapeng
 * @date 3:12 2020/2/17
 */
public class Tree {
    private DataTable dataTable;
    private Map<String, List<Integer>> table;// all samples: attribute name -> column
    private Node rootNode;// root node
    private List<Integer> endList;// class label of every row (the key column)

    private class Node {
        private Map<String, List<Integer>> fatherTable;// samples of the parent
        private Set<String> attribute;// attributes still available for splitting
        private double Ent;// information entropy
        private List<Node> nodeList;// child nodes
    }

    private class Gain {
        private double gain;// information gain of the split
        private double IV;// -sum(p * log2(p)) over the split's partition sizes
    }

    /**
     * @param dataTable table with a classification key set and at least one row
     * @throws Exception if the table has no key or no rows
     */
    Tree(DataTable dataTable) throws Exception {
        if (dataTable.getKey() != null && dataTable.getLength() > 0) {
            table = dataTable.getTable();
            this.dataTable = dataTable;
        } else {
            throw new Exception("dataTable is empty");
        }
    }

    /** Base-2 logarithm of p. */
    private double log2(double p) {
        return ArithUtil.div(Math.log(p), Math.log(2));
    }

    /**
     * Entropy of the class labels of a subset of rows.
     *
     * @param list row indices into {@code endList} (NOT label values)
     * @return -sum(p * log2(p)) over the label frequencies in the subset
     */
    private double getEnt(List<Integer> list) {
        // how many rows of the subset fall into each class
        Map<Integer, Integer> myType = new HashMap<>();
        for (int index : list) {
            int type = endList.get(index);// class label of this row
            Integer seen = myType.get(type);
            myType.put(type, seen == null ? 1 : seen + 1);
        }
        double ent = 0;
        for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
            double g = ArithUtil.div(entry1.getValue(), list.size());
            ent = ArithUtil.add(ent, ArithUtil.mul(g, log2(g)));
        }
        return -ent;
    }

    /** Adds one partition's weighted entropy (ent * dNub) to the running sum. */
    private double getGain(double ent, double dNub, double gain) {
        return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
    }

    /**
     * Evaluates splitting on one attribute column.
     * <p>
     * Rows are grouped by their value in the column; the positions inside
     * {@code dataBodyList} are used as row indices into {@code endList}, so the
     * column must be aligned row-for-row with the label column.
     *
     * @param dataBodyList values of the attribute, one per row
     * @param fatherEnt    entropy of the parent sample
     * @return the information gain and the IV term of the split
     */
    private Gain getGainNode(List<Integer> dataBodyList, double fatherEnt) {
        Map<Integer, List<Integer>> map = new HashMap<>();
        int fatherNub = dataBodyList.size();// total number of samples
        double gain = 0;// weighted entropy of the partitions
        double IV = 0;// sum(p * log2(p)) over partition sizes, negated at the end
        // group row indices by the attribute's discrete value
        for (int i = 0; i < dataBodyList.size(); i++) {
            int classification = dataBodyList.get(i);
            List<Integer> rows = map.get(classification);
            if (rows == null) {
                rows = new ArrayList<>();
                map.put(classification, rows);
            }
            rows.add(i);
        }
        // accumulate weighted entropy and IV over every partition
        for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
            List<Integer> list = entry.getValue();
            int myNub = list.size();
            double ent = getEnt(list);// entropy of this partition
            double dNub = ArithUtil.div(myNub, fatherNub);
            IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
            gain = getGain(ent, dNub, gain);
        }
        Gain gain1 = new Gain();
        gain1.gain = ArithUtil.sub(fatherEnt, gain);// information gain
        gain1.IV = -IV;
        return gain1;
    }

    /**
     * Evaluates every available attribute of the node.
     * Unfinished: the computed gains are not used yet and null is returned.
     */
    private Node createNode(Node node) {
        Map<String, List<Integer>> fatherTable = node.fatherTable;
        Set<String> attributes = node.attribute;
        double fatherEnt = node.Ent;
        for (String name : attributes) {
            List<Integer> dataBodyList = fatherTable.get(name);
            Gain gain = getGainNode(dataBodyList, fatherEnt);// information gain
        }
        return null;
    }

    /**
     * Starts learning from the table given to the constructor.
     *
     * @throws Exception if no data table is present
     */
    public void study() throws Exception {
        if (dataTable != null) {
            String labelName = dataTable.getKey();
            endList = dataTable.getTable().get(labelName);
            // getEnt expects row indices, not label values: feed it 0..n-1 so
            // the root entropy covers every row exactly once.
            List<Integer> allRows = new ArrayList<>(endList.size());
            for (int i = 0; i < endList.size(); i++) {
                allRows.add(i);
            }
            // The label column itself is not a split candidate; work on a copy
            // so the DataTable's own attribute set is left untouched.
            Set<String> candidates = new HashSet<>(dataTable.getKeyType());
            candidates.remove(labelName);
            Node node = new Node();
            node.attribute = candidates;// attributes currently available
            node.fatherTable = table;// current parent sample
            node.Ent = getEnt(allRows);
            createNode(node);
        } else {
            throw new Exception("dataTable is null");
        }
    }
}

@ -0,0 +1,27 @@
package org.wlld;
/**
 * Plain data holder with an integer {@code foodId} and a double {@code testId},
 * each exposed through a getter/setter pair.
 *
 * @author lidapeng
 * @date 8:11 2020/2/18
 */
public class Food {
    private double testId;
    private int foodId;

    /** @return the current testId */
    public double getTestId() {
        return this.testId;
    }

    /** @param testId new value for testId */
    public void setTestId(double testId) {
        this.testId = testId;
    }

    /** @return the current foodId */
    public int getFoodId() {
        return this.foodId;
    }

    /** @param foodId new value for foodId */
    public void setFoodId(int foodId) {
        this.foodId = foodId;
    }
}

@ -22,11 +22,14 @@ import java.util.Map;
*/ */
public class HelloWorld { public class HelloWorld {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
test(); int a = ModelData.DATA2.length();
System.out.println(a);
//test();
//testPic(); //testPic();
//testModel(); //testModel();
} }
public static void test() throws Exception { public static void test() throws Exception {
Picture picture = new Picture(); Picture picture = new Picture();
TempleConfig templeConfig = new TempleConfig(); TempleConfig templeConfig = new TempleConfig();
@ -54,8 +57,8 @@ public class HelloWorld {
// } // }
// templeConfig.boxStudy();//边框聚类 // templeConfig.boxStudy();//边框聚类
// //精准模式在全部学习结束的时候一定要使用此方法,速度模式不要调用此方法 // //精准模式在全部学习结束的时候一定要使用此方法,速度模式不要调用此方法
templeConfig.startLvq();//原型向量量化 templeConfig.startLvq();//原型向量量化
templeConfig.boxStudy();//边框回归 templeConfig.boxStudy();//边框回归
for (int j = 1; j < 2; j++) { for (int j = 1; j < 2; j++) {
Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png"); Matrix right = picture.getImageMatrixByLocal("/Users/lidapeng/Desktop/myDocment/c/c" + j + ".png");
Map<Integer, List<FrameBody>> map = operation.lookWithPosition(right, j); Map<Integer, List<FrameBody>> map = operation.lookWithPosition(right, j);

@ -3,6 +3,8 @@ package org.wlld;
import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation; import org.wlld.MatrixTools.MatrixOperation;
import java.util.*;
/** /**
* @author lidapeng * @author lidapeng
* @description * @description
@ -10,15 +12,16 @@ import org.wlld.MatrixTools.MatrixOperation;
*/ */
public class MatrixTest { public class MatrixTest {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
test4(); Map<Double, String> map = new TreeMap<>();
} map.put(3.0, "a");
map.put(2.0, "b");
map.put(4.0, "c");
map.put(5.0, "d");
map.put(1.0, "e");
for (Map.Entry<Double, String> entry : map.entrySet()) {
System.out.println(entry.getKey());
}
public static void test4() throws Exception {
Matrix matrix = new Matrix(1, 12);
String a = "[1,2,3,4,5,6,7,8,9,10,11,12]#";
matrix.setAll(a);
Matrix matrix1 = MatrixOperation.getPoolVector(matrix);
System.out.println(matrix1.getString());
} }
public static void test3() throws Exception { public static void test3() throws Exception {

Loading…
Cancel
Save