增加二叉树回归及后剪枝

pull/57/head
lidapeng 5 years ago
parent 2fdc4ba80d
commit dfe0b9576d

@ -165,6 +165,7 @@ public class Forest extends Frequency {
if (sub < shrinkParameter) {//需要剪枝,通知父级
fatherForest.getSonMessage(true, id);
isRemove = true;
//System.out.println("剪枝id==" + id + ",sub==" + sub + ",th==" + shrinkParameter);
} else {//通知父级,不需要剪枝,并将父级改为不可移除
fatherForest.getSonMessage(false, id);
}
@ -186,7 +187,6 @@ public class Forest extends Frequency {
public void cut() throws Exception {
int y = resultMatrix.getX();
if (y > 200) {
System.out.println("-======================");
double[] dm = findG();
int z = y / 2;
median = dm[z];
@ -201,10 +201,10 @@ public class Forest extends Frequency {
}
int leftId = 2 * id;
int rightId = leftId + 1;
//System.out.println("id:" + id + ",size:" + dm.length);
forestMap.put(id, this);
forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId);
forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId);
forestMap.put(leftId, forestLeft);
forestMap.put(rightId, forestRight);
forestRight.setFather(this);
forestLeft.setFather(this);
Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左

@ -216,7 +216,7 @@ public class RegressionForest extends Frequency {
}
}
//每一层从下到上进行剪枝
for (int i = layersNub - 1; i > 0; i++) {
for (int i = layersNub - 1; i > 0; i--) {
int min = (int) Math.pow(2, i);//最后一层最小的id
int maxNub = (int) Math.pow(2, i + 1);
for (Map.Entry<Integer, Forest> entry : forestMap.entrySet()) {

@ -13,16 +13,11 @@ import java.util.*;
*/
public class ForestTest {
public static void main(String[] args) throws Exception {
//test();
test();
//int a = (int) (Math.log(4) / Math.log(2));//id22是第几层
//double a = Math.pow(2, 5) - 1; 第五层的第一个数
// System.out.println("a==" + a);
TreeMap<Integer, String> map = new TreeMap<>();
map.put(5, "a");
map.put(3, "b");
map.put(4, "c");
map.put(6, "d");
map.put(7, "e");
}
public static void test() throws Exception {//对分段回归进行测试
@ -52,6 +47,8 @@ public class ForestTest {
}
double avs = sigma / size;
System.out.println("a误差" + avs);
// a误差0.0017585065712555645
// b误差0.00761733737464547
sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = b1.get(i);

Loading…
Cancel
Save