增加分段回归

pull/57/head
lidapeng 4 years ago
parent c13848884b
commit 7a4e006cec

@ -49,17 +49,6 @@ public class Forest extends Frequency {
this.resultVariance = resultVariance; this.resultVariance = resultVariance;
} }
//检测中位数median有多少个一样的值
private int getEqualNub(double median, double[] dm) {
int equalNub = 0;
for (int i = 0; i < dm.length; i++) {
if (median == dm[i]) {
equalNub++;
}
}
return equalNub;
}
private double[] findG() throws Exception {//寻找新的切入维度 private double[] findG() throws Exception {//寻找新的切入维度
// 先尝试从原有维度切入 // 先尝试从原有维度切入
int xSize = conditionMatrix.getX(); int xSize = conditionMatrix.getX();
@ -163,12 +152,14 @@ public class Forest extends Frequency {
public void pruning() {//进行后剪枝 public void pruning() {//进行后剪枝
if (forestLeft != null) { if (forestLeft != null) {
double leftDist = getDist(forestLeft.getW()); double leftDist = getDist(forestLeft.getW());
System.out.println("左剪枝阈值:" + leftDist);
if (leftDist < shrinkParameter) {//剪枝 if (leftDist < shrinkParameter) {//剪枝
forestLeft = null; forestLeft = null;
} }
} }
if (forestRight != null) { if (forestRight != null) {
double rightDist = getDist(forestRight.getW()); double rightDist = getDist(forestRight.getW());
System.out.println("右剪枝阈值:" + rightDist);
if (rightDist < shrinkParameter) { if (rightDist < shrinkParameter) {
forestRight = null; forestRight = null;
} }
@ -178,19 +169,25 @@ public class Forest extends Frequency {
public void cut() throws Exception { public void cut() throws Exception {
int y = resultMatrix.getX(); int y = resultMatrix.getX();
if (y > 8) { if (y > 200) {
double[] dm = findG(); double[] dm = findG();
Arrays.sort(dm);//排序
int z = y / 2; int z = y / 2;
median = dm[z]; median = dm[z];
//检测中位数median有多少个一样的值 int rightNub = 0;
int equalNub = getEqualNub(median, dm); int leftNub = 0;
for (int i = 0; i < dm.length; i++) {
if (dm[i] > median) {
rightNub++;
} else {
leftNub++;
}
}
forestLeft = new Forest(featureSize, shrinkParameter, pc); forestLeft = new Forest(featureSize, shrinkParameter, pc);
forestRight = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc);
Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左
Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右
Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左
Matrix resultMatrixRight = new Matrix(y - z - equalNub, 1);//结果矩阵右 Matrix resultMatrixRight = new Matrix(rightNub, 1);//结果矩阵右
forestLeft.setConditionMatrix(conditionMatrixLeft); forestLeft.setConditionMatrix(conditionMatrixLeft);
forestLeft.setResultMatrix(resultMatrixLeft); forestLeft.setResultMatrix(resultMatrixLeft);
forestRight.setConditionMatrix(conditionMatrixRight); forestRight.setConditionMatrix(conditionMatrixRight);

@ -79,16 +79,14 @@ public class RegressionForest extends Frequency {
private Forest getRegion(Forest forest, double result) { private Forest getRegion(Forest forest, double result) {
double median = forest.getMedian(); double median = forest.getMedian();
if (median > 0) {//进行了拆分 if (result > median && forest.getForestRight() != null) {//向右走
if (result > median) {//向右走
forest = forest.getForestRight(); forest = forest.getForestRight();
} else {//向左走 } else if (result <= median && forest.getForestLeft() != null) {//向左走
forest = forest.getForestLeft(); forest = forest.getForestLeft();
} } else {
return getRegion(forest, result);
} else {//没有拆分
return forest; return forest;
} }
return getRegion(forest, result);
} }
private Forest getLimitRegion(Forest forest, boolean isMax) { private Forest getLimitRegion(Forest forest, boolean isMax) {

@ -19,7 +19,7 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试 public static void test() throws Exception {//对分段回归进行测试
int size = 2000; int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); RegressionForest regressionForest = new RegressionForest(size, 3, 0.02);
regressionForest.setCosSize(40); regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1); List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2); List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2);

Loading…
Cancel
Save