|
|
|
@ -24,7 +24,7 @@ public class RegressionForest extends Frequency {
|
|
|
|
|
public RegressionForest(int size, int featureNub) throws Exception {//初始化
|
|
|
|
|
if (size > 0 && featureNub > 0) {
|
|
|
|
|
this.featureNub = featureNub;
|
|
|
|
|
w = new double[size];
|
|
|
|
|
w = new double[featureNub];
|
|
|
|
|
results = new double[size];
|
|
|
|
|
conditionMatrix = new Matrix(size, featureNub);
|
|
|
|
|
resultMatrix = new Matrix(size, 1);
|
|
|
|
@ -37,8 +37,56 @@ public class RegressionForest extends Frequency {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void getDist(double[] feature, double result) {//获取特征误差结果
|
|
|
|
|
public double getDist(double[] feature, double result) {//获取特征误差结果
|
|
|
|
|
Forest forestFinish;
|
|
|
|
|
if (result <= min) {//直接找下边界区域
|
|
|
|
|
forestFinish = getLimitRegion(forest, false);
|
|
|
|
|
} else if (result >= max) {//直接找到上边界区域
|
|
|
|
|
forestFinish = getLimitRegion(forest, true);
|
|
|
|
|
} else {
|
|
|
|
|
forestFinish = getRegion(forest, result);
|
|
|
|
|
}
|
|
|
|
|
//计算误差
|
|
|
|
|
double[] w = forestFinish.getW();
|
|
|
|
|
double sigma = 0;
|
|
|
|
|
for (int i = 0; i < w.length; i++) {
|
|
|
|
|
double nub;
|
|
|
|
|
if (i < w.length - 1) {
|
|
|
|
|
nub = w[i] * feature[i];
|
|
|
|
|
} else {
|
|
|
|
|
nub = w[i];
|
|
|
|
|
}
|
|
|
|
|
sigma = sigma + nub;
|
|
|
|
|
}
|
|
|
|
|
return Math.abs(result - sigma);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private Forest getRegion(Forest forest, double result) {
|
|
|
|
|
double median = forest.getMedian();
|
|
|
|
|
if (median > 0) {//进行了拆分
|
|
|
|
|
if (result > median) {//向右走
|
|
|
|
|
forest = forest.getForestRight();
|
|
|
|
|
} else {//向左走
|
|
|
|
|
forest = forest.getForestLeft();
|
|
|
|
|
}
|
|
|
|
|
return getRegion(forest, result);
|
|
|
|
|
} else {//没有拆分
|
|
|
|
|
return forest;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private Forest getLimitRegion(Forest forest, boolean isMax) {
|
|
|
|
|
Forest forestSon;
|
|
|
|
|
if (isMax) {
|
|
|
|
|
forestSon = forest.getForestRight();
|
|
|
|
|
} else {
|
|
|
|
|
forestSon = forest.getForestLeft();
|
|
|
|
|
}
|
|
|
|
|
if (forestSon != null) {
|
|
|
|
|
return getLimitRegion(forestSon, isMax);
|
|
|
|
|
} else {
|
|
|
|
|
return forest;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
public void insertFeature(double[] feature, double result) throws Exception {//插入数据
|
|
|
|
@ -99,7 +147,7 @@ public class RegressionForest extends Frequency {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private void regression(Forest forest) throws Exception {
|
|
|
|
|
private void regression(Forest forest) throws Exception {//对分段进行线性回归
|
|
|
|
|
Matrix conditionMatrix = forest.getConditionMatrix();
|
|
|
|
|
Matrix resultMatrix = forest.getResultMatrix();
|
|
|
|
|
double[] w = forest.getW();
|
|
|
|
|