修改决策树

pull/1/head
lidapeng 5 years ago
parent 4897a38e68
commit 63b53a5089

@ -6,8 +6,8 @@
<sourceOutputDir name="target/generated-sources/annotations" /> <sourceOutputDir name="target/generated-sources/annotations" />
<sourceTestOutputDir name="target/generated-test-sources/test-annotations" /> <sourceTestOutputDir name="target/generated-test-sources/test-annotations" />
<outputRelativeToContentRoot value="true" /> <outputRelativeToContentRoot value="true" />
<module name="ImageMarket" />
<module name="myBrain" /> <module name="myBrain" />
<module name="ImageMarket" />
</profile> </profile>
</annotationProcessing> </annotationProcessing>
</component> </component>

@ -16,15 +16,17 @@ public class Tree {//决策树
private List<Integer> endList;//final result class of each sample private List<Integer> endList;//final result class of each sample
// A node of the decision tree (left column: removed fields, right column: added fields).
private class Node { private class Node {
private Map<String, List<Integer>> fatherTable;//parent samples private boolean isEnd = false;//set to true by createNode when no attributes remain
private List<Integer> fatherList;//parent samples (sample indices)
private Set<String> attribute;//currently available attributes private Set<String> attribute;//currently available attributes
private double Ent;//information entropy private String key;//classification attribute of this node
private List<Node> nodeList;//child nodes private List<Node> nodeList;//child nodes
private int type;//class chosen by getType; assigned by createNode when no attributes remain
} }
// Split metrics computed for one attribute (left column: removed fields, right column: added fields).
private class Gain { private class Gain {
private double gain; private double gain;//information gain
private double IV; private double gainRatio;//gain ratio: gain divided by the negated IV sum
} }
Tree(DataTable dataTable) throws Exception { Tree(DataTable dataTable) throws Exception {
@ -64,48 +66,116 @@ public class Tree {//决策树
return ArithUtil.add(gain, ArithUtil.mul(ent, dNub)); return ArithUtil.add(gain, ArithUtil.mul(ent, dNub));
} }
// Side-by-side diff rows: the left column is the removed getGainNode, the right column is the added
// createNode followed by the added getType.
// createNode (right column): groups the node's sample indices by the value of every remaining
// attribute, computes information gain and gain ratio per attribute, selects the attribute whose gain
// exceeds the average gain with the largest gain ratio, and recurses into that attribute's child nodes.
// With no attributes left it marks the node as terminal, sets its type through getType and returns null.
private Gain getGainNode(List<Integer> dataBodyList, double fatherEnt) { private List<Node> createNode(Node node) {
Map<Integer, List<Integer>> map = new HashMap<>(); Set<String> attributes = node.attribute;
int fatherNub = dataBodyList.size();//total number of samples List<Integer> fatherList = node.fatherList;
double gain = 0;//information gain if (attributes.size() > 0) {
double IV = 0;//gain ratio Map<String, Map<Integer, List<Integer>>> mapAll = new HashMap<>();
//sample sets for each discrete value of this attribute double fatherEnt = getEnt(fatherList);
for (int i = 0; i < dataBodyList.size(); i++) { int fatherNub = fatherList.size();//total number of samples
int classification = dataBodyList.get(i);//current attribute //sample sets for each discrete value of this attribute
if (map.containsKey(classification)) { for (int i = 0; i < fatherList.size(); i++) {
List<Integer> list = map.get(classification); int index = fatherList.get(i);//sample index
list.add(i); for (String attr : attributes) {
if (!mapAll.containsKey(attr)) {
mapAll.put(attr, new HashMap<>());
}
Map<Integer, List<Integer>> map = mapAll.get(attr);
int attrValue = table.get(attr).get(index);
if (!map.containsKey(attrValue)) {
map.put(attrValue, new ArrayList<>());
}
List<Integer> list = map.get(attrValue);
list.add(index);
}
}
Map<String, List<Node>> nodeMap = new HashMap<>();
int i = 0;
double sigmaG = 0;
Map<String, Gain> gainMap = new HashMap<>();
for (Map.Entry<String, Map<Integer, List<Integer>>> mapEntry : mapAll.entrySet()) {
Map<Integer, List<Integer>> map = mapEntry.getValue();
//compute the information gain
double gain = 0;//information gain
double IV = 0;//gain ratio (accumulates the sum of dNub * log2(dNub); negated below as the divisor)
List<Node> nodeList = new ArrayList<>();
String name = mapEntry.getKey();
nodeMap.put(name, nodeList);
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {
Set<String> nowAttribute = removeAttribute(attributes, name);
Node sonNode = new Node();
nodeList.add(sonNode);
sonNode.key = mapEntry.getKey();
sonNode.attribute = nowAttribute;
List<Integer> list = entry.getValue();
sonNode.fatherList = list;
int myNub = list.size();
double ent = getEnt(list);//each entropy is computed over one subset
double dNub = ArithUtil.div(myNub, fatherNub);
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV);
gain = getGain(ent, dNub, gain);
}
Gain gain1 = new Gain();
gainMap.put(name, gain1);
gain1.gain = ArithUtil.sub(fatherEnt, gain);//information gain
// NOTE(review): IV is 0 when the attribute has a single value within fatherList (dNub is 1,
// presumably log2(1) == 0), so the divisor below is -0 — confirm how ArithUtil.div handles a zero divisor.
gain1.gainRatio = ArithUtil.div(gain1.gain, -IV);//gain ratio
sigmaG = ArithUtil.add(gain1.gain, sigmaG);
i++;
}
double avgGain = ArithUtil.div(sigmaG, i);
double gainRatio = 0;//largest gain ratio so far
String key = null;//selected attribute
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
Gain gain = entry.getValue();
if (gain.gain > avgGain && gain.gainRatio > gainRatio) {
gainRatio = gain.gainRatio;
key = entry.getKey();
}
}
// NOTE(review): key stays null when no entry passes the test above — e.g. with a single remaining
// attribute its gain equals avgGain (assuming ArithUtil arithmetic is exact), so "gain.gain > avgGain"
// is false. nodeMap.get(null) then presumably yields null and nodeList.size() throws
// NullPointerException — confirm and guard.
List<Node> nodeList = nodeMap.get(key);
for (int j = 0; j < nodeList.size(); j++) {
Node node1 = nodeList.get(j);
node1.nodeList = createNode(node1);
}
return nodeList;
} else {
//determine the class
node.isEnd = true;
node.type = getType(fatherList);
return null;
}
}
// getType (right column): counts, for the given sample indices, how often each class from endList
// occurs and returns the class with the strictly largest count (0 when the list is empty).
private int getType(List<Integer> list) {
Map<Integer, Integer> myType = new HashMap<>();
for (int index : list) {
int type = endList.get(index);//class of the final result
if (myType.containsKey(type)) {
myType.put(type, myType.get(type) + 1);
} else { } else {
List<Integer> list = new ArrayList<>(); myType.put(type, 1);
list.add(i);
map.put(classification, list);
} }
} }
//compute the information gain int type = 0;
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) { int nub = 0;
List<Integer> list = entry.getValue(); for (Map.Entry<Integer, Integer> entry : myType.entrySet()) {
int myNub = list.size(); int nowNub = entry.getValue();
double ent = getEnt(list);//each entropy is computed over one subset if (nowNub > nub) {
double dNub = ArithUtil.div(myNub, fatherNub); type = entry.getKey();
IV = ArithUtil.add(ArithUtil.mul(dNub, log2(dNub)), IV); nub = nowNub;
gain = getGain(ent, dNub, gain); }
} }
Gain gain1 = new Gain(); return type;
gain1.gain = ArithUtil.sub(fatherEnt, gain);//information gain
gain1.IV = -IV;
return gain1;
} }
// Side-by-side diff rows: the left column is the removed createNode stub (computed a Gain per
// attribute and returned null); the right column is the added removeAttribute.
// removeAttribute (right column): returns a new HashSet holding every name in attributes except
// the given name; the input set itself is not modified.
private Node createNode(Node node) { private Set<String> removeAttribute(Set<String> attributes, String name) {
Map<String, List<Integer>> fatherTable = node.fatherTable; Set<String> attriBute = new HashSet<>();
Set<String> attributes = node.attribute; for (String myName : attributes) {
double fatherEnt = node.Ent; if (!myName.equals(name)) {
for (String name : attributes) { attriBute.add(myName);
List<Integer> dataBodyList = fatherTable.get(name); }
Gain gain = getGainNode(dataBodyList, fatherEnt);//information gain
} }
return null; return attriBute;
} }
public void study() throws Exception { public void study() throws Exception {
@ -113,8 +183,11 @@ public class Tree {//决策树
endList = dataTable.getTable().get(dataTable.getKey()); endList = dataTable.getTable().get(dataTable.getKey());
Node node = new Node(); Node node = new Node();
node.attribute = dataTable.getKeyType();//当前可用属性 node.attribute = dataTable.getKeyType();//当前可用属性
node.fatherTable = table;//当前父级样本 List<Integer> list = new ArrayList<>();
node.Ent = getEnt(endList); for (int i = 0; i < endList.size(); i++) {
list.add(i);
}
node.fatherList = list;//当前父级样本
createNode(node); createNode(node);
} else { } else {
throw new Exception("dataTable is null"); throw new Exception("dataTable is null");

Loading…
Cancel
Save