增加分段回归

pull/57/head
lidapeng 4 years ago
parent c13848884b
commit 7a4e006cec

@ -49,17 +49,6 @@ public class Forest extends Frequency {
this.resultVariance = resultVariance;
}
//检测中位数median有多少个一样的值
private int getEqualNub(double median, double[] dm) {
int equalNub = 0;
for (int i = 0; i < dm.length; i++) {
if (median == dm[i]) {
equalNub++;
}
}
return equalNub;
}
private double[] findG() throws Exception {//寻找新的切入维度
// 先尝试从原有维度切入
int xSize = conditionMatrix.getX();
@ -163,12 +152,14 @@ public class Forest extends Frequency {
public void pruning() {//进行后剪枝
if (forestLeft != null) {
double leftDist = getDist(forestLeft.getW());
System.out.println("左剪枝阈值:" + leftDist);
if (leftDist < shrinkParameter) {//剪枝
forestLeft = null;
}
}
if (forestRight != null) {
double rightDist = getDist(forestRight.getW());
System.out.println("右剪枝阈值:" + rightDist);
if (rightDist < shrinkParameter) {
forestRight = null;
}
@ -178,19 +169,25 @@ public class Forest extends Frequency {
public void cut() throws Exception {
int y = resultMatrix.getX();
if (y > 8) {
if (y > 200) {
double[] dm = findG();
Arrays.sort(dm);//排序
int z = y / 2;
median = dm[z];
//检测中位数median有多少个一样的值
int equalNub = getEqualNub(median, dm);
int rightNub = 0;
int leftNub = 0;
for (int i = 0; i < dm.length; i++) {
if (dm[i] > median) {
rightNub++;
} else {
leftNub++;
}
}
forestLeft = new Forest(featureSize, shrinkParameter, pc);
forestRight = new Forest(featureSize, shrinkParameter, pc);
Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左
Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右
Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左
Matrix resultMatrixRight = new Matrix(y - z - equalNub, 1);//结果矩阵右
Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左
Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右
Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左
Matrix resultMatrixRight = new Matrix(rightNub, 1);//结果矩阵右
forestLeft.setConditionMatrix(conditionMatrixLeft);
forestLeft.setResultMatrix(resultMatrixLeft);
forestRight.setConditionMatrix(conditionMatrixRight);

@ -79,16 +79,14 @@ public class RegressionForest extends Frequency {
private Forest getRegion(Forest forest, double result) {
double median = forest.getMedian();
if (median > 0) {//进行了拆分
if (result > median) {//向右走
if (result > median && forest.getForestRight() != null) {//向右走
forest = forest.getForestRight();
} else {//向左走
} else if (result <= median && forest.getForestLeft() != null) {//向左走
forest = forest.getForestLeft();
}
return getRegion(forest, result);
} else {//没有拆分
} else {
return forest;
}
return getRegion(forest, result);
}
private Forest getLimitRegion(Forest forest, boolean isMax) {

@ -19,7 +19,7 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.2);
RegressionForest regressionForest = new RegressionForest(size, 3, 0.02);
regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2);

Loading…
Cancel
Save