增加分段回归

pull/57/head
lidapeng 4 years ago
parent df70d9b20f
commit d17b58c3c1

@ -17,7 +17,7 @@ public class Forest extends Frequency {
private Matrix resultMatrix;//结果矩阵
private Forest forestLeft;//左森林
private Forest forestRight;//右森林
private int size;
private int featureSize;
private double min;//下限
private double max;//上限
private double resultVariance;//结果矩阵方差
@ -25,9 +25,10 @@ public class Forest extends Frequency {
private double shrinkParameter;//方差收缩参数
private double[] w;
public Forest(int size, double shrinkParameter) {
this.size = size;
public Forest(int featureSize, double shrinkParameter) {
this.featureSize = featureSize;
this.shrinkParameter = shrinkParameter;
w = new double[featureSize];
}
public double getResultVariance() {
@ -48,10 +49,10 @@ public class Forest extends Frequency {
Arrays.sort(dm);//排序
int z = y / 2;
median = dm[z];
forestLeft = new Forest(size, shrinkParameter);
forestRight = new Forest(size, shrinkParameter);
Matrix conditionMatrixLeft = new Matrix(z, size);//条件矩阵左
Matrix conditionMatrixRight = new Matrix(y - z, size);//条件矩阵右
forestLeft = new Forest(featureSize, shrinkParameter);
forestRight = new Forest(featureSize, shrinkParameter);
Matrix conditionMatrixLeft = new Matrix(z, featureSize);//条件矩阵左
Matrix conditionMatrixRight = new Matrix(y - z, featureSize);//条件矩阵右
Matrix resultMatrixLeft = new Matrix(z, 1);//结果矩阵左
Matrix resultMatrixRight = new Matrix(y - z, 1);//结果矩阵右
forestLeft.setConditionMatrix(conditionMatrixLeft);
@ -65,14 +66,14 @@ public class Forest extends Frequency {
for (int i = 0; i < y; i++) {
double nub = resultMatrix.getNumber(i, 0);//结果矩阵
if (nub > median) {//进入右森林并计算右森林结果矩阵方差
for (int j = 0; j < size; j++) {//进入右森林的条件矩阵
for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵
conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j));
}
resultRight[rightIndex] = nub;
resultMatrixRight.setNub(rightIndex, 0, nub);
rightIndex++;
} else {//进入左森林并计算左森林结果矩阵方差
for (int j = 0; j < size; j++) {//进入右森林的条件矩阵
for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵
conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j));
}
resultLeft[leftIndex] = nub;
@ -137,4 +138,12 @@ public class Forest extends Frequency {
public void setW(double[] w) {
this.w = w;
}
public Forest getForestLeft() {
return forestLeft;
}
public Forest getForestRight() {
return forestRight;
}
}

@ -2,6 +2,7 @@ package org.wlld.regressionForest;
import org.wlld.MatrixTools.Matrix;
import org.wlld.MatrixTools.MatrixOperation;
import org.wlld.tools.Frequency;
/**
* @param
@ -9,30 +10,44 @@ import org.wlld.MatrixTools.MatrixOperation;
* @Author LiDaPeng
* @Description
*/
public class RegressionForest {
public class RegressionForest extends Frequency {
private double[] w;
private Matrix conditionMatrix;//条件矩阵
private Matrix resultMatrix;//结果矩阵
private Forest forest;
private int featureNub;//特征数量
private int xIndex = 0;//记录插入位置
private double[] results;//结果数组
private double min;//结果最小值
private double max;//结果最大值
public RegressionForest(int size, int featureNub) throws Exception {//初始化
if (size > 0 && featureNub > 0) {
this.featureNub = featureNub;
w = new double[size];
results = new double[size];
conditionMatrix = new Matrix(size, featureNub);
resultMatrix = new Matrix(size, 1);
forest = new Forest(featureNub, 0.9);
forest.setW(w);
forest.setConditionMatrix(conditionMatrix);
forest.setResultMatrix(resultMatrix);
} else {
throw new Exception("size and featureNub too small");
}
}
public void getDist(double[] feature, double result) {//获取特征误差结果
}
public void insertFeature(double[] feature, double result) throws Exception {//插入数据
if (feature.length == featureNub) {
for (int i = 0; i < featureNub; i++) {
if (i < featureNub - 1) {
conditionMatrix.setNub(xIndex, i, feature[i]);
} else {
results[xIndex] = result;
conditionMatrix.setNub(xIndex, i, 1.0);
resultMatrix.setNub(xIndex, 0, result);
}
@ -43,14 +58,54 @@ public class RegressionForest {
}
}
public void start() throws Exception {//开始进行分段
if (forest != null) {
double[] limit = getLimit(results);
min = limit[0];
max = limit[1];
start(forest);
} else {
throw new Exception("rootForest is null");
}
}
private void start(Forest forest) throws Exception {
forest.cut();
Forest forestLeft = forest.getForestLeft();
Forest forestRight = forest.getForestRight();
if (forestLeft != null && forestRight != null) {
start(forestLeft);
start(forestRight);
}
}
public void regression() throws Exception {//开始进行回归
if (xIndex > 0) {
Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix);
for (int i = 0; i < ws.getX(); i++) {
w[i] = ws.getNumber(i, 0);
}
if (forest != null) {
regressionTree(forest);
} else {
throw new Exception("regression matrix size is zero");
throw new Exception("rootForest is null");
}
}
private void regressionTree(Forest forest) throws Exception {
regression(forest);
Forest forestLeft = forest.getForestLeft();
Forest forestRight = forest.getForestRight();
if (forestLeft != null && forestRight != null) {
regressionTree(forestLeft);
regressionTree(forestRight);
}
}
private void regression(Forest forest) throws Exception {
Matrix conditionMatrix = forest.getConditionMatrix();
Matrix resultMatrix = forest.getResultMatrix();
double[] w = forest.getW();
Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix);
for (int i = 0; i < ws.getX(); i++) {
w[i] = ws.getNumber(i, 0);
}
}
}
Loading…
Cancel
Save