增加分段回归

pull/57/head
lidapeng 4 years ago
parent 7a4e006cec
commit c7b00adf0c

@ -97,7 +97,7 @@ public class Forest extends Frequency {
pc1 = g; pc1 = g;
} }
} }
//找到非原始基最离散的新基 //找到非原始基最离散的新基:
if (max > maxOld) {//使用新基 if (max > maxOld) {//使用新基
isOldG = false; isOldG = false;
} else {//使用原有基 } else {//使用原有基
@ -150,17 +150,18 @@ public class Forest extends Frequency {
} }
public void pruning() {//进行后剪枝 public void pruning() {//进行后剪枝
System.out.println("执行了剪枝====");
if (forestLeft != null) { if (forestLeft != null) {
double leftDist = getDist(forestLeft.getW()); double leftDist = getDist(forestLeft.getW());
System.out.println("左剪枝阈值:" + leftDist);
if (leftDist < shrinkParameter) {//剪枝 if (leftDist < shrinkParameter) {//剪枝
System.out.println("成功左剪枝阈值:" + leftDist + ",阈值:" + shrinkParameter);
forestLeft = null; forestLeft = null;
} }
} }
if (forestRight != null) { if (forestRight != null) {
double rightDist = getDist(forestRight.getW()); double rightDist = getDist(forestRight.getW());
System.out.println("右剪枝阈值:" + rightDist);
if (rightDist < shrinkParameter) { if (rightDist < shrinkParameter) {
System.out.println("成功右剪枝阈值:" + rightDist + ",阈值:" + shrinkParameter);
forestRight = null; forestRight = null;
} }
} }
@ -170,6 +171,7 @@ public class Forest extends Frequency {
public void cut() throws Exception { public void cut() throws Exception {
int y = resultMatrix.getX(); int y = resultMatrix.getX();
if (y > 200) { if (y > 200) {
System.out.println("-======================");
double[] dm = findG(); double[] dm = findG();
int z = y / 2; int z = y / 2;
median = dm[z]; median = dm[z];

@ -215,8 +215,8 @@ public class RegressionForest extends Frequency {
if (forest != null) { if (forest != null) {
forest.pruning(); forest.pruning();
Forest forestRight = forest.getForestRight(); Forest forestRight = forest.getForestRight();
pruningTree(forestRight);
Forest forestLeft = forest.getForestLeft(); Forest forestLeft = forest.getForestLeft();
pruningTree(forestRight);
pruningTree(forestLeft); pruningTree(forestLeft);
} }
} }

@ -19,10 +19,10 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试 public static void test() throws Exception {//对分段回归进行测试
int size = 2000; int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.02); RegressionForest regressionForest = new RegressionForest(size, 3, 0.01);
regressionForest.setCosSize(40); regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1); List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2); List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
double[] featureA = a.get(i); double[] featureA = a.get(i);
double[] featureB = b.get(i); double[] featureB = b.get(i);
@ -32,21 +32,24 @@ public class ForestTest {
regressionForest.insertFeature(testB, featureB[2]); regressionForest.insertFeature(testB, featureB[2]);
} }
regressionForest.startStudy(); regressionForest.startStudy();
///
List<double[]> a1 = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b1 = fun(0.3, 0.2, 0.6, size, 2, 2);
double sigma = 0; double sigma = 0;
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
double[] feature = a.get(i); double[] feature = a1.get(i);
double[] test = new double[]{feature[0], feature[1]}; double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]); double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + Math.pow(dist, 2); sigma = sigma + dist;
} }
double avs = sigma / size; double avs = sigma / size;
System.out.println("a误差" + avs); System.out.println("a误差" + avs);
sigma = 0; sigma = 0;
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
double[] feature = b.get(i); double[] feature = b1.get(i);
double[] test = new double[]{feature[0], feature[1]}; double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]); double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + Math.pow(dist, 2); sigma = sigma + dist;
} }
double avs2 = sigma / size; double avs2 = sigma / size;
System.out.println("b误差" + avs2); System.out.println("b误差" + avs2);

Loading…
Cancel
Save