增加分段回归

pull/57/head
lidapeng 4 years ago
parent 9e72af8212
commit 5d53a3415d

@ -150,7 +150,29 @@ public class Forest extends Frequency {
return data;
}
public void pruning() {//进行剪枝
private double getDist(double[] data) {
int len = data.length;
double sigma = 0;
for (int i = 0; i < len; i++) {
double sub = data[i] - w[i];
sigma = sigma + Math.pow(sub, 2);
}
return sigma / len;
}
public void pruning() {//进行后剪枝
if (forestLeft != null) {
double leftDist = getDist(forestLeft.getW());
if (leftDist < shrinkParameter) {//剪枝
forestLeft = null;
}
}
if (forestRight != null) {
double rightDist = getDist(forestRight.getW());
if (rightDist < shrinkParameter) {
forestRight = null;
}
}
}

@ -188,6 +188,8 @@ public class RegressionForest extends Frequency {
start(forest);
//进行回归
regression();
//进行剪枝
pruning();
} else {
throw new Exception("rootForest is null");
}
@ -203,6 +205,24 @@ public class RegressionForest extends Frequency {
}
}
private void pruning() throws Exception {
if (forest != null) {
pruningTree(forest);
} else {
throw new Exception("rootForest is null");
}
}
private void pruningTree(Forest forest) {
if (forest != null) {
forest.pruning();
Forest forestRight = forest.getForestRight();
pruningTree(forestRight);
Forest forestLeft = forest.getForestLeft();
pruningTree(forestLeft);
}
}
private void regression() throws Exception {//开始进行回归
if (forest != null) {
regressionTree(forest);

Loading…
Cancel
Save