增加分段回归

pull/57/head
lidapeng 4 years ago
parent d17b58c3c1
commit cdeea5bcec

@ -18,8 +18,6 @@ public class Forest extends Frequency {
private Forest forestLeft;//左森林
private Forest forestRight;//右森林
private int featureSize;
private double min;//下限
private double max;//上限
private double resultVariance;//结果矩阵方差
private double median;//结果矩阵中位数
private double shrinkParameter;//方差收缩参数
@ -31,6 +29,10 @@ public class Forest extends Frequency {
w = new double[featureSize];
}
public double getMedian() {
return median;
}
public double getResultVariance() {
return resultVariance;
}
@ -85,36 +87,14 @@ public class Forest extends Frequency {
double leftVar = variance(resultLeft);
double rightVar = variance(resultRight);
double variance = resultVariance * shrinkParameter;
if (leftVar < variance || rightVar < variance) {//继续拆分
double[] left = getLimit(resultLeft);
double[] right = getLimit(resultRight);
forestLeft.setMin(left[0]);
forestLeft.setMax(left[1]);
forestRight.setMin(right[0]);
forestRight.setMax(right[1]);
} else {//不继续拆分
if (leftVar > variance && rightVar > variance) {//不进行拆分,回退
forestLeft = null;
forestRight = null;
median = 0;
}
}
}
public double getMin() {
return min;
}
public void setMin(double min) {
this.min = min;
}
public double getMax() {
return max;
}
public void setMax(double max) {
this.max = max;
}
public Matrix getConditionMatrix() {
return conditionMatrix;
}

@ -24,7 +24,7 @@ public class RegressionForest extends Frequency {
public RegressionForest(int size, int featureNub) throws Exception {//初始化
if (size > 0 && featureNub > 0) {
this.featureNub = featureNub;
w = new double[size];
w = new double[featureNub];
results = new double[size];
conditionMatrix = new Matrix(size, featureNub);
resultMatrix = new Matrix(size, 1);
@ -37,8 +37,56 @@ public class RegressionForest extends Frequency {
}
}
public void getDist(double[] feature, double result) {//获取特征误差结果
public double getDist(double[] feature, double result) {//获取特征误差结果
Forest forestFinish;
if (result <= min) {//直接找下边界区域
forestFinish = getLimitRegion(forest, false);
} else if (result >= max) {//直接找到上边界区域
forestFinish = getLimitRegion(forest, true);
} else {
forestFinish = getRegion(forest, result);
}
//计算误差
double[] w = forestFinish.getW();
double sigma = 0;
for (int i = 0; i < w.length; i++) {
double nub;
if (i < w.length - 1) {
nub = w[i] * feature[i];
} else {
nub = w[i];
}
sigma = sigma + nub;
}
return Math.abs(result - sigma);
}
private Forest getRegion(Forest forest, double result) {
double median = forest.getMedian();
if (median > 0) {//进行了拆分
if (result > median) {//向右走
forest = forest.getForestRight();
} else {//向左走
forest = forest.getForestLeft();
}
return getRegion(forest, result);
} else {//没有拆分
return forest;
}
}
private Forest getLimitRegion(Forest forest, boolean isMax) {
Forest forestSon;
if (isMax) {
forestSon = forest.getForestRight();
} else {
forestSon = forest.getForestLeft();
}
if (forestSon != null) {
return getLimitRegion(forestSon, isMax);
} else {
return forest;
}
}
public void insertFeature(double[] feature, double result) throws Exception {//插入数据
@ -99,7 +147,7 @@ public class RegressionForest extends Frequency {
}
private void regression(Forest forest) throws Exception {
private void regression(Forest forest) throws Exception {//对分段进行线性回归
Matrix conditionMatrix = forest.getConditionMatrix();
Matrix resultMatrix = forest.getResultMatrix();
double[] w = forest.getW();

Loading…
Cancel
Save