diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index fbe3c0f..9127010 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -27,6 +27,8 @@ public class Forest extends Frequency { private double[] w; private boolean isOldG = true;//是否使用老基 private int oldGId = 0;//老基的id + private Matrix matrixAll;//full matrix: condition columns with the result in the last column + private double gNorm;//norm of the new dimension public Forest(int featureSize, double shrinkParameter, Matrix pc) { this.featureSize = featureSize; @@ -58,17 +60,17 @@ public class Forest extends Frequency { return equalNub; } - private void findG() throws Exception {//寻找新的切入维度 + private double[] findG() throws Exception {//find the new dimension to cut along // 先尝试从原有维度切入 int xSize = conditionMatrix.getX(); int ySize = conditionMatrix.getY(); - Matrix matrix = new Matrix(xSize, ySize); + matrixAll = new Matrix(xSize, ySize); for (int i = 0; i < xSize; i++) { for (int j = 0; j < ySize; j++) { if (j < ySize - 1) { - matrix.setNub(i, j, conditionMatrix.getNumber(i, j)); + matrixAll.setNub(i, j, conditionMatrix.getNumber(i, j)); } else { - matrix.setNub(i, j, resultMatrix.getNumber(i, 0)); + matrixAll.setNub(i, j, resultMatrix.getNumber(i, 0)); } } } @@ -83,7 +85,7 @@ public class Forest extends Frequency { g[j] = resultMatrix.getNumber(j, 0); } } - double var = variance(g);//计算方差 + double var = dc(g);//compute the variance (per the original note; dc() is not shown in this patch) if (var > maxOld) { maxOld = var; type = i; @@ -96,11 +98,11 @@ public class Forest extends Frequency { double gNorm = MatrixOperation.getNorm(g); double[] var = new double[xSize]; for (int j = 0; j < xSize; j++) { - Matrix parameter = matrix.getRow(j); + Matrix parameter = matrixAll.getRow(j); double dist = transG(g, parameter, gNorm); var[j] = dist; } - double variance = variance(var); + double variance = dc(var); if (variance > max) { max = variance; pc1 = g; @@ -113,6 +115,7 @@ public class Forest extends Frequency { isOldG = true; oldGId = type; } + return 
findTwo(xSize); } private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 @@ -121,19 +124,42 @@ public class Forest extends Frequency { return innerProduct / gNorm; } + private double[] findTwo(int dataSize) throws Exception { + Matrix matrix;//column vector holding the chosen dimension + double[] data = new double[dataSize]; + if (isOldG) {//use an original basis + if (oldGId == featureSize - 1) {//take the data from the result matrix + matrix = resultMatrix; + } else {//take the data from the condition matrix + matrix = conditionMatrix.getColumn(oldGId); + } + //copy the data into the array + for (int i = 0; i < dataSize; i++) { + data[i] = matrix.getNumber(i, 0); + } + } else {//use the transformed basis + int x = matrixAll.getX(); + gNorm = MatrixOperation.getNorm(pc1); + for (int i = 0; i < x; i++) { + Matrix parameter = matrixAll.getRow(i); + double dist = transG(pc1, parameter, gNorm); + data[i] = dist; + } + } + Arrays.sort(data);//sort the data + return data; + } + public void cut() throws Exception { int y = resultMatrix.getX(); - if (y > 4) { - double[] dm = new double[y]; - for (int i = 0; i < y; i++) { - dm[i] = resultMatrix.getNumber(i, 0); - } + if (y > 8) { + double[] dm = findG(); Arrays.sort(dm);//排序 int z = y / 2; median = dm[z]; //检测中位数median有多少个一样的值 int equalNub = getEqualNub(median, dm); - //System.out.println("equalNub==" + equalNub + ",y==" + y); + //////////// forestLeft = new Forest(featureSize, shrinkParameter, pc); forestRight = new Forest(featureSize, shrinkParameter, pc); Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 @@ -148,8 +174,20 @@ public class Forest extends Frequency { int rightIndex = 0;//右矩阵添加行数 double[] resultLeft = new double[z + equalNub]; double[] resultRight = new double[y - z - equalNub]; + ////// for (int i = 0; i < y; i++) { - double nub = resultMatrix.getNumber(i, 0);//结果矩阵 + double nub; + if (isOldG) {//use an original basis + if (oldGId == featureSize - 1) {//take the data from the result matrix + nub = resultMatrix.getNumber(i, 0); + } else {//take the data from the condition matrix + nub = conditionMatrix.getNumber(i, oldGId); + } + } else {//use the new basis + Matrix parameter = 
matrixAll.getRow(i); + nub = transG(pc1, parameter, gNorm); + } + //double nub = resultMatrix.getNumber(i, 0);//result matrix if (nub > median) {//进入右森林并计算右森林结果矩阵方差 for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index ae99e42..00d9a6c 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -26,7 +26,7 @@ public class RegressionForest extends Frequency { private double min;//结果最小值 private double max;//结果最大值 private Matrix pc;//需要映射的基 - private int cosSize = 10;//cos 分成几份 + private int cosSize = 20;//number of parts the cosine range is split into public int getCosSize() { return cosSize; @@ -116,8 +116,8 @@ public class RegressionForest extends Frequency { } double cosOne = 1.0D / cosSize; double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 - for (int i = 1; i < cosSize; i++) { - double cos = cosOne * i; + for (int i = 0; i < cosSize - 1; i++) { + double cos = cosOne * (i + 1); ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); } int x = (cosSize - 1) * featureNub; diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java index a01b63d..66cd738 100644 --- a/src/test/java/coverTest/ForestTest.java +++ b/src/test/java/coverTest/ForestTest.java @@ -20,8 +20,9 @@ public class ForestTest { public static void test() throws Exception {//对分段回归进行测试 int size = 2000; RegressionForest regressionForest = new RegressionForest(size, 3, 0.2); - List a = fun(0.1, 0.2, 0.3, size); - List b = fun(0.3, 0.2, 0.1, size); + regressionForest.setCosSize(40); + List a = fun(0.1, 0.2, 0.3, size, 2, 1); + List b = fun(0.3, 0.2, 0.1, size, 2, 2); for (int i = 0; i < 1000; i++) { double[] featureA = a.get(i); double[] featureB = b.get(i); @@ -54,12 +55,14 @@ public class ForestTest { } - public 
static List fun(double w1, double w2, double w3, int size) {//生成假数据 + public static List fun(double w1, double w2, double w3, int size, int region, int index) {//generate fake data List list = new ArrayList<>(); Random random = new Random(); + int nub = (index - 1) * 100; + double max = region * 100; for (int i = 0; i < size; i++) { + double b = (double) (random.nextInt(100) + nub) / max; double a = random.nextDouble(); - double b = random.nextDouble(); double c = w1 * a + w2 * b + w3; double[] data = new double[]{a, b, c}; list.add(data);