增加分段回归

pull/57/head
lidapeng 5 years ago
parent de0722e587
commit 79ccfdf340

@ -27,6 +27,8 @@ public class Forest extends Frequency {
private double[] w; private double[] w;
private boolean isOldG = true;//是否使用老基 private boolean isOldG = true;//是否使用老基
private int oldGId = 0;//老基的id private int oldGId = 0;//老基的id
private Matrix matrixAll;//全矩阵
private double gNorm;//新维度的摸
public Forest(int featureSize, double shrinkParameter, Matrix pc) { public Forest(int featureSize, double shrinkParameter, Matrix pc) {
this.featureSize = featureSize; this.featureSize = featureSize;
@ -58,17 +60,17 @@ public class Forest extends Frequency {
return equalNub; return equalNub;
} }
private void findG() throws Exception {//寻找新的切入维度 private double[] findG() throws Exception {//寻找新的切入维度
// 先尝试从原有维度切入 // 先尝试从原有维度切入
int xSize = conditionMatrix.getX(); int xSize = conditionMatrix.getX();
int ySize = conditionMatrix.getY(); int ySize = conditionMatrix.getY();
Matrix matrix = new Matrix(xSize, ySize); matrixAll = new Matrix(xSize, ySize);
for (int i = 0; i < xSize; i++) { for (int i = 0; i < xSize; i++) {
for (int j = 0; j < ySize; j++) { for (int j = 0; j < ySize; j++) {
if (j < ySize - 1) { if (j < ySize - 1) {
matrix.setNub(i, j, conditionMatrix.getNumber(i, j)); matrixAll.setNub(i, j, conditionMatrix.getNumber(i, j));
} else { } else {
matrix.setNub(i, j, resultMatrix.getNumber(i, 0)); matrixAll.setNub(i, j, resultMatrix.getNumber(i, 0));
} }
} }
} }
@ -83,7 +85,7 @@ public class Forest extends Frequency {
g[j] = resultMatrix.getNumber(j, 0); g[j] = resultMatrix.getNumber(j, 0);
} }
} }
double var = variance(g);//计算方差 double var = dc(g);//计算方差
if (var > maxOld) { if (var > maxOld) {
maxOld = var; maxOld = var;
type = i; type = i;
@ -96,11 +98,11 @@ public class Forest extends Frequency {
double gNorm = MatrixOperation.getNorm(g); double gNorm = MatrixOperation.getNorm(g);
double[] var = new double[xSize]; double[] var = new double[xSize];
for (int j = 0; j < xSize; j++) { for (int j = 0; j < xSize; j++) {
Matrix parameter = matrix.getRow(j); Matrix parameter = matrixAll.getRow(j);
double dist = transG(g, parameter, gNorm); double dist = transG(g, parameter, gNorm);
var[j] = dist; var[j] = dist;
} }
double variance = variance(var); double variance = dc(var);
if (variance > max) { if (variance > max) {
max = variance; max = variance;
pc1 = g; pc1 = g;
@ -113,6 +115,7 @@ public class Forest extends Frequency {
isOldG = true; isOldG = true;
oldGId = type; oldGId = type;
} }
return findTwo(xSize);
} }
private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基
@ -121,19 +124,42 @@ public class Forest extends Frequency {
return innerProduct / gNorm; return innerProduct / gNorm;
} }
private double[] findTwo(int dataSize) throws Exception {
Matrix matrix;//创建一个列向量
double[] data = new double[dataSize];
if (isOldG) {//使用原有基
if (oldGId == featureSize - 1) {//从结果矩阵提取数据
matrix = resultMatrix;
} else {//从条件矩阵中提取数据
matrix = conditionMatrix.getColumn(oldGId);
}
//将数据塞入数组
for (int i = 0; i < dataSize; i++) {
data[i] = matrix.getNumber(i, 0);
}
} else {//使用转换基
int x = matrixAll.getX();
gNorm = MatrixOperation.getNorm(pc1);
for (int i = 0; i < x; i++) {
Matrix parameter = matrixAll.getRow(i);
double dist = transG(pc1, parameter, gNorm);
data[i] = dist;
}
}
Arrays.sort(data);//对数据进行排序
return data;
}
public void cut() throws Exception { public void cut() throws Exception {
int y = resultMatrix.getX(); int y = resultMatrix.getX();
if (y > 4) { if (y > 8) {
double[] dm = new double[y]; double[] dm = findG();
for (int i = 0; i < y; i++) {
dm[i] = resultMatrix.getNumber(i, 0);
}
Arrays.sort(dm);//排序 Arrays.sort(dm);//排序
int z = y / 2; int z = y / 2;
median = dm[z]; median = dm[z];
//检测中位数median有多少个一样的值 //检测中位数median有多少个一样的值
int equalNub = getEqualNub(median, dm); int equalNub = getEqualNub(median, dm);
//System.out.println("equalNub==" + equalNub + ",y==" + y); ////////////
forestLeft = new Forest(featureSize, shrinkParameter, pc); forestLeft = new Forest(featureSize, shrinkParameter, pc);
forestRight = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc);
Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左
@ -148,8 +174,20 @@ public class Forest extends Frequency {
int rightIndex = 0;//右矩阵添加行数 int rightIndex = 0;//右矩阵添加行数
double[] resultLeft = new double[z + equalNub]; double[] resultLeft = new double[z + equalNub];
double[] resultRight = new double[y - z - equalNub]; double[] resultRight = new double[y - z - equalNub];
//////
for (int i = 0; i < y; i++) { for (int i = 0; i < y; i++) {
double nub = resultMatrix.getNumber(i, 0);//结果矩阵 double nub;
if (isOldG) {//使用原有基
if (oldGId == featureSize - 1) {//从结果矩阵提取数据
nub = resultMatrix.getNumber(i, 0);
} else {//从条件矩阵中提取数据
nub = conditionMatrix.getNumber(i, oldGId);
}
} else {//使用新基
Matrix parameter = matrixAll.getRow(i);
nub = transG(pc1, parameter, gNorm);
}
//double nub = resultMatrix.getNumber(i, 0);//结果矩阵
if (nub > median) {//进入右森林并计算右森林结果矩阵方差 if (nub > median) {//进入右森林并计算右森林结果矩阵方差
for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵
conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j));

@ -26,7 +26,7 @@ public class RegressionForest extends Frequency {
private double min;//结果最小值 private double min;//结果最小值
private double max;//结果最大值 private double max;//结果最大值
private Matrix pc;//需要映射的基 private Matrix pc;//需要映射的基
private int cosSize = 10;//cos 分成几份 private int cosSize = 20;//cos 分成几份
public int getCosSize() { public int getCosSize() {
return cosSize; return cosSize;
@ -116,8 +116,8 @@ public class RegressionForest extends Frequency {
} }
double cosOne = 1.0D / cosSize; double cosOne = 1.0D / cosSize;
double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值
for (int i = 1; i < cosSize; i++) { for (int i = 0; i < cosSize - 1; i++) {
double cos = cosOne * i; double cos = cosOne * (i + 1);
ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1));
} }
int x = (cosSize - 1) * featureNub; int x = (cosSize - 1) * featureNub;

@ -20,8 +20,9 @@ public class ForestTest {
public static void test() throws Exception {//对分段回归进行测试 public static void test() throws Exception {//对分段回归进行测试
int size = 2000; int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); RegressionForest regressionForest = new RegressionForest(size, 3, 0.2);
List<double[]> a = fun(0.1, 0.2, 0.3, size); regressionForest.setCosSize(40);
List<double[]> b = fun(0.3, 0.2, 0.1, size); List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.3, 0.2, 0.1, size, 2, 2);
for (int i = 0; i < 1000; i++) { for (int i = 0; i < 1000; i++) {
double[] featureA = a.get(i); double[] featureA = a.get(i);
double[] featureB = b.get(i); double[] featureB = b.get(i);
@ -54,12 +55,14 @@ public class ForestTest {
} }
public static List<double[]> fun(double w1, double w2, double w3, int size) {//生成假数据 public static List<double[]> fun(double w1, double w2, double w3, int size, int region, int index) {//生成假数据
List<double[]> list = new ArrayList<>(); List<double[]> list = new ArrayList<>();
Random random = new Random(); Random random = new Random();
int nub = (index - 1) * 100;
double max = region * 100;
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
double b = (double) (random.nextInt(100) + nub) / max;
double a = random.nextDouble(); double a = random.nextDouble();
double b = random.nextDouble();
double c = w1 * a + w2 * b + w3; double c = w1 * a + w2 * b + w3;
double[] data = new double[]{a, b, c}; double[] data = new double[]{a, b, c};
list.add(data); list.add(data);

Loading…
Cancel
Save