From de0722e5877afc95c2a9d0cbcde074a754eb8b12 Mon Sep 17 00:00:00 2001 From: lidapeng <794757862@qq.com> Date: Mon, 14 Sep 2020 17:26:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=88=86=E6=AE=B5=E5=9B=9E?= =?UTF-8?q?=E5=BD=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/wlld/MatrixTools/MatrixOperation.java | 2 +- .../org/wlld/regressionForest/Forest.java | 101 ++++++++++-------- .../regressionForest/RegressionForest.java | 72 ++++++++++++- 3 files changed, 128 insertions(+), 47 deletions(-) diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index ec5f647..32e4c82 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -243,7 +243,7 @@ public class MatrixOperation { double nub = 0; for (int i = 0; i < matrix.getX(); i++) { for (int j = 0; j < matrix.getY(); j++) { - nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub); + nub = Math.pow(matrix.getNumber(i, j), 2) + nub; } } return Math.sqrt(nub); diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java index c738ff5..fbe3c0f 100644 --- a/src/main/java/org/wlld/regressionForest/Forest.java +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -1,12 +1,10 @@ package org.wlld.regressionForest; import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.Random; +import java.util.*; /** @@ -24,13 +22,16 @@ public class Forest extends Frequency { private double resultVariance;//结果矩阵方差 private double median;//结果矩阵中位数 private double shrinkParameter;//方差收缩参数 - private Matrix pc;//需要映射的基 + private Matrix pc;//需要映射的基的集合 + private Matrix pc1;//需要映射的基 private double[] w; - private int cosSize = 10;//cos 分成几份 + private boolean isOldG = true;//是否使用老基 + private int oldGId = 0;//老基的id - public Forest(int featureSize, double shrinkParameter) { + public Forest(int featureSize, double shrinkParameter, Matrix pc) { this.featureSize = featureSize; this.shrinkParameter = shrinkParameter; + this.pc = pc; w = new double[featureSize]; } @@ -57,44 +58,22 @@ public class Forest extends Frequency { return equalNub; } - private void createG() throws Exception {//生成新基 - double[] cg = new double[featureSize - 1]; - Random random = new Random(); - double sigma = 0; - for (int i = 0; i < featureSize - 1; i++) { - double rm = random.nextDouble(); - cg[i] = rm; - sigma = sigma + Math.pow(rm, 2); - } - double cosOne = 1.0D / cosSize; - double[] ag = new double[cosSize - 1]; - for (int i = 1; i < cosSize; i++) { - double cos = cosOne * i; - ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); - } - int x = (cosSize - 1) * featureSize; - pc = new Matrix(x, featureSize); - for (int i = 0; i < featureSize; i++) { - Matrix matrix = new Matrix(ag.length, featureSize); - for (int j = 0; j < ag.length; j++) { - for (int k = 0; k < featureSize; k++) { - if (k != i) { - if (k < i) { - matrix.setNub(j, k, cg[k]); - } else { - matrix.setNub(j, k, cg[k - 1]); - } - } else { - matrix.setNub(j, k, ag[j]); - } + private void findG() throws Exception {//寻找新的切入维度 + // 先尝试从原有维度切入 + int xSize = conditionMatrix.getX(); + int ySize = conditionMatrix.getY(); + Matrix matrix = new Matrix(xSize, ySize); + for (int i = 0; i < xSize; i++) { + for (int j = 0; j < ySize; j++) { + if (j < ySize - 1) { + matrix.setNub(i, j, conditionMatrix.getNumber(i, j)); + } else { + matrix.setNub(i, j, resultMatrix.getNumber(i, 0)); } } } - } - - private void findG() throws Exception {//寻找新的切入维度 - // 先尝试从原有维度切入 - Map varMap = new HashMap<>();//保存原有维度方差 + double maxOld = 0; + int type = 0; for (int i = 0; i < featureSize; i++) { double[] g = new double[conditionMatrix.getX()]; for (int j = 0; j < g.length; j++) { @@ -105,9 +84,41 @@ public class Forest extends Frequency { } } double var = variance(g);//计算方差 - varMap.put(i, var); + if (var > maxOld) { + maxOld = var; + type = i; + } + } + int x = pc.getX(); + double max = 0; + for (int i = 0; i < x; i++) { + Matrix g = pc.getRow(i); + double gNorm = MatrixOperation.getNorm(g); + double[] var = new double[xSize]; + for (int j = 0; j < xSize; j++) { + Matrix parameter = matrix.getRow(j); + double dist = transG(g, parameter, gNorm); + var[j] = dist; + } + double variance = variance(var); + if (variance > max) { + max = variance; + pc1 = g; + } } + //找到非原始基最离散的新基 + if (max > maxOld) {//使用新基 + isOldG = false; + } else {//使用原有基 + isOldG = true; + oldGId = type; + } + } + private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 + //先求内积 + double innerProduct = MatrixOperation.innerProduct(g, parameter); + return innerProduct / gNorm; } public void cut() throws Exception { @@ -123,8 +134,8 @@ public class Forest extends Frequency { //检测中位数median有多少个一样的值 int equalNub = getEqualNub(median, dm); //System.out.println("equalNub==" + equalNub + ",y==" + y); - forestLeft = new Forest(featureSize, shrinkParameter); - forestRight = new Forest(featureSize, shrinkParameter); + forestLeft = new Forest(featureSize, shrinkParameter, pc); + forestRight = new Forest(featureSize, shrinkParameter, pc); Matrix conditionMatrixLeft = new Matrix(z + equalNub, featureSize);//条件矩阵左 Matrix conditionMatrixRight = new Matrix(y - z - equalNub, featureSize);//条件矩阵右 Matrix resultMatrixLeft = new Matrix(z + equalNub, 1);//结果矩阵左 diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java index 1b73888..ae99e42 100644 --- a/src/main/java/org/wlld/regressionForest/RegressionForest.java +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -5,6 +5,9 @@ import org.wlld.MatrixTools.MatrixOperation; import org.wlld.tools.Frequency; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; /** * @param @@ -22,6 +25,16 @@ public class RegressionForest extends Frequency { private double[] results;//结果数组 private double min;//结果最小值 private double max;//结果最大值 + private Matrix pc;//需要映射的基 + private int cosSize = 10;//cos 分成几份 + + public int getCosSize() { + return cosSize; + } + + public void setCosSize(int cosSize) { + this.cosSize = cosSize; + } public RegressionForest(int size, int featureNub, double shrinkParameter) throws Exception {//初始化 if (size > 0 && featureNub > 0) { @@ -30,7 +43,8 @@ public class RegressionForest extends Frequency { results = new double[size]; conditionMatrix = new Matrix(size, featureNub); resultMatrix = new Matrix(size, 1); - forest = new Forest(featureNub, shrinkParameter); + createG(); + forest = new Forest(featureNub, shrinkParameter, pc); forest.setW(w); forest.setConditionMatrix(conditionMatrix); forest.setResultMatrix(resultMatrix); @@ -91,6 +105,62 @@ public class RegressionForest extends Frequency { } } + private void createG() throws Exception {//生成新基 + double[] cg = new double[featureNub - 1]; + Random random = new Random(); + double sigma = 0; + for (int i = 0; i < featureNub - 1; i++) { + double rm = random.nextDouble(); + cg[i] = rm; + sigma = sigma + Math.pow(rm, 2); + } + double cosOne = 1.0D / cosSize; + double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 + for (int i = 1; i < cosSize; i++) { + double cos = cosOne * i; + ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); + } + int x = (cosSize - 1) * featureNub; + pc = new Matrix(x, featureNub); + for (int i = 0; i < featureNub; i++) {//遍历所有的固定基 + //以某个固定基摆动的所有新基集合的矩阵 + Matrix matrix = new Matrix(ag.length, featureNub); + for (int j = 0; j < ag.length; j++) { + for (int k = 0; k < featureNub; k++) { + if (k != i) { + if (k < i) { + matrix.setNub(j, k, cg[k]); + } else { + matrix.setNub(j, k, cg[k - 1]); + } + } else { + matrix.setNub(j, k, ag[j]); + } + } + } + //将一个固定基内摆动的新基都装到最大的集合内 + int index = (cosSize - 1) * i; + push(pc, matrix, index); + } + } + + //将两个矩阵从上到下进行合并 + private void push(Matrix mother, Matrix son, int index) throws Exception { + if (mother.getY() == son.getY()) { + int x = index + son.getX(); + int y = mother.getY(); + int start = 0; + for (int i = index; i < x; i++) { + for (int j = 0; j < y; j++) { + mother.setNub(i, j, son.getNumber(start, j)); + } + start++; + } + } else { + throw new Exception("matrix Y is not equals"); + } + } + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 if (feature.length == featureNub - 1) { for (int i = 0; i < featureNub; i++) {