增加分段回归

pull/57/head
lidapeng 4 years ago
parent 7a4e006cec
commit c7b00adf0c

@ -97,7 +97,7 @@ public class Forest extends Frequency {
pc1 = g;
}
}
//找到非原始基最离散的新基
//找到非原始基最离散的新基:
if (max > maxOld) {//使用新基
isOldG = false;
} else {//使用原有基
@ -150,17 +150,18 @@ public class Forest extends Frequency {
}
public void pruning() {//进行后剪枝
System.out.println("执行了剪枝====");
if (forestLeft != null) {
double leftDist = getDist(forestLeft.getW());
System.out.println("左剪枝阈值:" + leftDist);
if (leftDist < shrinkParameter) {//剪枝
System.out.println("成功左剪枝阈值:" + leftDist + ",阈值:" + shrinkParameter);
forestLeft = null;
}
}
if (forestRight != null) {
double rightDist = getDist(forestRight.getW());
System.out.println("右剪枝阈值:" + rightDist);
if (rightDist < shrinkParameter) {
System.out.println("成功右剪枝阈值:" + rightDist + ",阈值:" + shrinkParameter);
forestRight = null;
}
}
@ -170,6 +171,7 @@ public class Forest extends Frequency {
public void cut() throws Exception {
int y = resultMatrix.getX();
if (y > 200) {
System.out.println("-======================");
double[] dm = findG();
int z = y / 2;
median = dm[z];

@ -215,8 +215,8 @@ public class RegressionForest extends Frequency {
if (forest != null) {
forest.pruning();
Forest forestRight = forest.getForestRight();
pruningTree(forestRight);
Forest forestLeft = forest.getForestLeft();
pruningTree(forestRight);
pruningTree(forestLeft);
}
}

@ -19,10 +19,10 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.02);
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01);
regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2);
List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);
for (int i = 0; i < 1000; i++) {
double[] featureA = a.get(i);
double[] featureB = b.get(i);
@ -32,21 +32,24 @@ public class ForestTest {
regressionForest.insertFeature(testB, featureB[2]);
}
regressionForest.startStudy();
///
List<double[]> a1 = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b1 = fun(0.3, 0.2, 0.6, size, 2, 2);
double sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = a.get(i);
double[] feature = a1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + Math.pow(dist, 2);
sigma = sigma + dist;
}
double avs = sigma / size;
System.out.println("a误差" + avs);
sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = b.get(i);
double[] feature = b1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + Math.pow(dist, 2);
sigma = sigma + dist;
}
double avs2 = sigma / size;
System.out.println("b误差" + avs2);

Loading…
Cancel
Save