diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index 431394d..7ba3e49 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -18,8 +18,6 @@ public class Forest extends Frequency { private Forest forestLeft;//左森林 private Forest forestRight;//右森林 private int featureSize; - private double min;//下限 - private double max;//上限 private double resultVariance;//结果矩阵方差 private double median;//结果矩阵中位数 private double shrinkParameter;//方差收缩参数 @@ -31,6 +29,10 @@ public class Forest extends Frequency { w = new double[featureSize]; } + public double getMedian() { + return median; + } + public double getResultVariance() { return resultVariance; } @@ -85,36 +87,14 @@ public class Forest extends Frequency { double leftVar = variance(resultLeft); double rightVar = variance(resultRight); double variance = resultVariance * shrinkParameter; - if (leftVar < variance || rightVar < variance) {//继续拆分 - double[] left = getLimit(resultLeft); - double[] right = getLimit(resultRight); - forestLeft.setMin(left[0]); - forestLeft.setMax(left[1]); - forestRight.setMin(right[0]); - forestRight.setMax(right[1]); - } else {//不继续拆分 + if (leftVar > variance && rightVar > variance) {//不进行拆分,回退 forestLeft = null; forestRight = null; + median = 0; } } } - public double getMin() { - return min; - } - - public void setMin(double min) { - this.min = min; - } - - public double getMax() { - return max; - } - - public void setMax(double max) { - this.max = max; - } - public Matrix getConditionMatrix() { return conditionMatrix; } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 99a254d..14f7fa7 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -24,7 +24,7 @@ public class RegressionForest extends Frequency { public RegressionForest(int size, int featureNub) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; - w = new double[size]; + w = new double[featureNub]; results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); @@ -37,8 +37,56 @@ public class RegressionForest extends Frequency { } } - public void getDist(double[] feature, double result) {//获取特征误差结果 + public double getDist(double[] feature, double result) {//获取特征误差结果 + Forest forestFinish; + if (result <= min) {//直接找下边界区域 + forestFinish = getLimitRegion(forest, false); + } else if (result >= max) {//直接找到上边界区域 + forestFinish = getLimitRegion(forest, true); + } else { + forestFinish = getRegion(forest, result); + } + //计算误差 + double[] w = forestFinish.getW(); + double sigma = 0; + for (int i = 0; i < w.length; i++) { + double nub; + if (i < w.length - 1) { + nub = w[i] * feature[i]; + } else { + nub = w[i]; + } + sigma = sigma + nub; + } + return Math.abs(result - sigma); + } + private Forest getRegion(Forest forest, double result) { + double median = forest.getMedian(); + if (median > 0) {//进行了拆分 + if (result > median) {//向右走 + forest = forest.getForestRight(); + } else {//向左走 + forest = forest.getForestLeft(); + } + return getRegion(forest, result); + } else {//没有拆分 + return forest; + } + } + + private Forest getLimitRegion(Forest forest, boolean isMax) { + Forest forestSon; + if (isMax) { + forestSon = forest.getForestRight(); + } else { + forestSon = forest.getForestLeft(); + } + if (forestSon != null) { + return getLimitRegion(forestSon, isMax); + } else { + return forest; + } } public void insertFeature(double[] feature, double result) throws Exception {//插入数据 @@ -99,7 +147,7 @@ public class RegressionForest extends Frequency { } - private void regression(Forest forest) throws Exception { + private void regression(Forest forest) throws Exception {//对分段进行线性回归 Matrix conditionMatrix = forest.getConditionMatrix(); Matrix resultMatrix = forest.getResultMatrix(); double[] w = forest.getW();