From d17b58c3c11556743afde1aa0d8604d753df5630 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 7 Sep 2020 17:22:21 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5=E5=9B=9E?= =?UTF-8?q?=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/regressionForest/Forest.java | 27 +++++--- .../regressionForest/RegressionForest.java | 69 +++++++++++++++++-- 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c5accd5..431394d 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -17,7 +17,7 @@ public class Forest extends Frequency { private Matrix resultMatrix;//结果矩阵 private Forest forestLeft;//左森林 private Forest forestRight;//右森林 - private int size; + private int featureSize; private double min;//下限 private double max;//上限 private double resultVariance;//结果矩阵方差 @@ -25,9 +25,10 @@ public class Forest extends Frequency { private double shrinkParameter;//方差收缩参数 private double[] w; - public Forest(int size, double shrinkParameter) { - this.size = size; + public Forest(int featureSize, double shrinkParameter) { + this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; + w = new double[featureSize]; } public double getResultVariance() { @@ -48,10 +49,10 @@ public class Forest extends Frequency { Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; - forestLeft = new Forest(size, shrinkParameter); - forestRight = new Forest(size, shrinkParameter); - Matrix conditionMatrixLeft = new Matrix(z, size);//条件矩阵左 - Matrix conditionMatrixRight = new Matrix(y - z, size);//条件矩阵右 + forestLeft = new Forest(featureSize, shrinkParameter); + forestRight = new Forest(featureSize, shrinkParameter); + Matrix conditionMatrixLeft = new Matrix(z, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(y - z, featureSize);//条件矩阵右 Matrix resultMatrixLeft = new Matrix(z, 1);//结果矩阵左 Matrix resultMatrixRight = new Matrix(y - z, 1);//结果矩阵右 forestLeft.setConditionMatrix(conditionMatrixLeft); @@ -65,14 +66,14 @@ public class Forest extends Frequency { for (int i = 0; i < y; i++) { double nub = resultMatrix.getNumber(i, 0);//结果矩阵 if (nub > median) {//进入右森林并计算右森林结果矩阵方差 - for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); } resultRight[rightIndex] = nub; resultMatrixRight.setNub(rightIndex, 0, nub); rightIndex++; } else {//进入左森林并计算左森林结果矩阵方差 - for (int j = 0; j < size; j++) {//进入右森林的条件矩阵 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j)); } resultLeft[leftIndex] = nub; @@ -137,4 +138,12 @@ public class Forest extends Frequency { public void setW(double[] w) { this.w = w; } + + public Forest getForestLeft() { + return forestLeft; + } + + public Forest getForestRight() { + return forestRight; + } } diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index d3690e0..99a254d 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -2,6 +2,7 @@ package org.wlld.regressionForest; import org.wlld.MatrixTools.Matrix; import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.tools.Frequency; /** * @param @@ -9,30 +10,44 @@ import org.wlld.MatrixTools.MatrixOperation; * @Author LiDaPeng * @Description 回归森林 */ -public class RegressionForest { +public class RegressionForest extends Frequency { private double[] w; private Matrix conditionMatrix;//条件矩阵 private Matrix resultMatrix;//结果矩阵 + private Forest forest; private int featureNub;//特征数量 private int xIndex = 0;//记录插入位置 + private double[] results;//结果数组 + private double min;//结果最小值 + private double max;//结果最大值 public RegressionForest(int size, int featureNub) throws Exception {//初始化 if (size > 0 && featureNub > 0) { this.featureNub = featureNub; w = new double[size]; + results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); + forest = new Forest(featureNub, 0.9); + forest.setW(w); + forest.setConditionMatrix(conditionMatrix); + forest.setResultMatrix(resultMatrix); } else { throw new Exception("size and featureNub too small"); } } + public void getDist(double[] feature, double result) {//获取特征误差结果 + + } + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 if (feature.length == featureNub) { for (int i = 0; i < featureNub; i++) { if (i < featureNub - 1) { conditionMatrix.setNub(xIndex, i, feature[i]); } else { + results[xIndex] = result; conditionMatrix.setNub(xIndex, i, 1.0); resultMatrix.setNub(xIndex, 0, result); } @@ -43,14 +58,54 @@ public class RegressionForest { } } + public void start() throws Exception {//开始进行分段 + if (forest != null) { + double[] limit = getLimit(results); + min = limit[0]; + max = limit[1]; + start(forest); + } else { + throw new Exception("rootForest is null"); + } + } + + private void start(Forest forest) throws Exception { + forest.cut(); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + start(forestLeft); + start(forestRight); + } + + } + public void regression() throws Exception {//开始进行回归 - if (xIndex > 0) { - Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); - for (int i = 0; i < ws.getX(); i++) { - w[i] = ws.getNumber(i, 0); - } + if (forest != null) { + regressionTree(forest); } else { - throw new Exception("regression matrix size is zero"); + throw new Exception("rootForest is null"); + } + } + + private void regressionTree(Forest forest) throws Exception { + regression(forest); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + regressionTree(forestLeft); + regressionTree(forestRight); + } + + } + + private void regression(Forest forest) throws Exception { + Matrix conditionMatrix = forest.getConditionMatrix(); + Matrix resultMatrix = forest.getResultMatrix(); + double[] w = forest.getW(); + Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); + for (int i = 0; i < ws.getX(); i++) { + w[i] = ws.getNumber(i, 0); } } } \ No newline at end of file