diff --git a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java index ec5f647..32e4c82 100644 --- a/src/main/java/org/wlld/MatrixTools/MatrixOperation.java +++ b/src/main/java/org/wlld/MatrixTools/MatrixOperation.java @@ -243,7 +243,7 @@ public class MatrixOperation { double nub = 0; for (int i = 0; i < matrix.getX(); i++) { for (int j = 0; j < matrix.getY(); j++) { - nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub); + nub = Math.pow(matrix.getNumber(i, j), 2) + nub; } } return Math.sqrt(nub); diff --git a/src/main/java/org/wlld/imageRecognition/border/Knn.java b/src/main/java/org/wlld/imageRecognition/border/Knn.java index 68fa9a4..81a0f48 100644 --- a/src/main/java/org/wlld/imageRecognition/border/Knn.java +++ b/src/main/java/org/wlld/imageRecognition/border/Knn.java @@ -22,6 +22,22 @@ public class Knn {//KNN分类器 featureMap.remove(type); } + public void revoke(int type, int nub) {//撤销一个类别最新的 + List list = featureMap.get(type); + for (int i = 0; i < nub; i++) { + list.remove(list.size() - 1); + } + } + + public int getNub(int type) {//获取该分类模型的数量 + int nub = 0; + List list = featureMap.get(type); + if (list != null) { + nub = list.size(); + } + return nub; + } + public void insertMatrix(Matrix vector, int tag) throws Exception { if (vector.isVector() && vector.isRowVector()) { if (featureMap.size() == 0) { diff --git a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java index 56958cb..6612b9f 100644 --- a/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java +++ b/src/main/java/org/wlld/imageRecognition/segmentation/Watershed.java @@ -277,13 +277,13 @@ public class Watershed { regionBodies.add(regionBody); } } -// for (RegionBody regionBody : regionBodies) { -// int minX = regionBody.getMinX(); -// int maxX = regionBody.getMaxX(); -// int minY = regionBody.getMinY(); -// int maxY = regionBody.getMaxY(); -// System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); -// } + for (RegionBody regionBody : regionBodies) { + int minX = regionBody.getMinX(); + int maxX = regionBody.getMaxX(); + int minY = regionBody.getMinY(); + int maxY = regionBody.getMaxY(); + System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); + } return iou(regionBodies); } diff --git a/src/main/java/org/wlld/regressionForest/Forest.java b/src/main/java/org/wlld/regressionForest/Forest.java new file mode 100644 index 0000000..f3f12b4 --- /dev/null +++ b/src/main/java/org/wlld/regressionForest/Forest.java @@ -0,0 +1,303 @@ +package org.wlld.regressionForest; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.tools.Frequency; + +import java.util.*; + + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 分段切割容器 + */ +public class Forest extends Frequency { + private Matrix conditionMatrix;//条件矩阵 + private Matrix resultMatrix;//结果矩阵 + private Forest forestLeft;//左森林 + private Forest forestRight;//右森林 + private int featureSize; + private double resultVariance;//结果矩阵方差 + private double median;//结果矩阵中位数 + private double shrinkParameter;//方差收缩参数 + private Matrix pc;//需要映射的基的集合 + private Matrix pc1;//需要映射的基 + private double[] w; + private boolean isOldG = true;//是否使用老基 + private int oldGId = 0;//老基的id + private Matrix matrixAll;//全矩阵 + private double gNorm;//新维度的摸 + private Forest father;//父级 + private Map forestMap;//尽头列表 + private int id;//本节点的id + private boolean isRemove = false;//是否已经被移除了 + private boolean notRemovable = false;//不可移除 + private int minGrain;//最小粒度 + + public Forest(int featureSize, double shrinkParameter, Matrix pc, Map forestMap + , int id, int minGrain) { + this.featureSize = featureSize; + this.shrinkParameter = shrinkParameter; + this.pc = pc; + w = new double[featureSize]; + this.forestMap = forestMap; + this.id = id; + this.minGrain = minGrain; + } + + public double getMedian() { + return median; + } + + public double getResultVariance() { + return resultVariance; + } + + public void setResultVariance(double resultVariance) { + this.resultVariance = resultVariance; + } + + private double[] findG() throws Exception {//寻找新的切入维度 + // 先尝试从原有维度切入 + int xSize = conditionMatrix.getX(); + int ySize = conditionMatrix.getY(); + matrixAll = new Matrix(xSize, ySize); + for (int i = 0; i < xSize; i++) { + for (int j = 0; j < ySize; j++) { + if (j < ySize - 1) { + matrixAll.setNub(i, j, conditionMatrix.getNumber(i, j)); + } else { + matrixAll.setNub(i, j, resultMatrix.getNumber(i, 0)); + } + } + } + double maxOld = 0; + int type = 0; + for (int i = 0; i < featureSize; i++) { + double[] g = new double[conditionMatrix.getX()]; + for (int j = 0; j < g.length; j++) { + if (i < featureSize - 1) { + g[j] = conditionMatrix.getNumber(j, i); + } else { + g[j] = resultMatrix.getNumber(j, 0); + } + } + double var = dc(g);//计算方差 + if (var > maxOld) { + maxOld = var; + type = i; + } + } + int x = pc.getX(); + double max = 0; + for (int i = 0; i < x; i++) { + Matrix g = pc.getRow(i); + double gNorm = MatrixOperation.getNorm(g); + double[] var = new double[xSize]; + for (int j = 0; j < xSize; j++) { + Matrix parameter = matrixAll.getRow(j); + double dist = transG(g, parameter, gNorm); + var[j] = dist; + } + double variance = dc(var); + if (variance > max) { + max = variance; + pc1 = g; + } + } + //找到非原始基最离散的新基: + if (max > maxOld) {//使用新基 + isOldG = false; + } else {//使用原有基 + isOldG = true; + oldGId = type; + } + return findTwo(xSize); + } + + private double transG(Matrix g, Matrix parameter, double gNorm) throws Exception {//将数据映射到新基 + //先求内积 + double innerProduct = MatrixOperation.innerProduct(g, parameter); + return innerProduct / gNorm; + } + + private double[] findTwo(int dataSize) throws Exception { + Matrix matrix;//创建一个列向量 + double[] data = new double[dataSize]; + if (isOldG) {//使用原有基 + if (oldGId == featureSize - 1) {//从结果矩阵提取数据 + matrix = resultMatrix; + } else {//从条件矩阵中提取数据 + matrix = conditionMatrix.getColumn(oldGId); + } + //将数据塞入数组 + for (int i = 0; i < dataSize; i++) { + data[i] = matrix.getNumber(i, 0); + } + } else {//使用转换基 + int x = matrixAll.getX(); + gNorm = MatrixOperation.getNorm(pc1); + for (int i = 0; i < x; i++) { + Matrix parameter = matrixAll.getRow(i); + double dist = transG(pc1, parameter, gNorm); + data[i] = dist; + } + } + Arrays.sort(data);//对数据进行排序 + return data; + } + + private double getDist(double[] data, double[] w) { + int len = data.length; + double sigma = 0; + for (int i = 0; i < len; i++) { + double sub = data[i] - w[i]; + sigma = sigma + Math.pow(sub, 2); + } + return sigma / len; + } + + public void pruning() {//进行后剪枝,跟父级进行比较 + if (!notRemovable) { + Forest fatherForest = this.getFather(); + double[] fatherW = fatherForest.getW(); + double sub = getDist(w, fatherW); + if (sub < shrinkParameter) {//需要剪枝,通知父级 + fatherForest.getSonMessage(true, id); + isRemove = true; + //System.out.println("剪枝id==" + id + ",sub==" + sub + ",th==" + shrinkParameter); + } else {//通知父级,不需要剪枝,并将父级改为不可移除 + fatherForest.getSonMessage(false, id); + } + } + } + + public void getSonMessage(boolean isPruning, int myId) {//进行剪枝 + if (isPruning) {//剪枝 + if (myId == id * 2) {//左节点 + forestLeft = null; + } else {//右节点 + forestRight = null; + } + } else {//不剪枝,将自己变为不可剪枝状态 + notRemovable = true; + } + } + + public void cut() throws Exception { + int y = resultMatrix.getX(); + if (y > minGrain) { + double[] dm = findG(); + int z = y / 2; + median = dm[z]; + int rightNub = 0; + int leftNub = 0; + for (int i = 0; i < dm.length; i++) { + if (dm[i] > median) { + rightNub++; + } else { + leftNub++; + } + } + int leftId = 2 * id; + int rightId = leftId + 1; + //System.out.println("id:" + id + ",size:" + dm.length); + forestMap.put(id, this); + forestLeft = new Forest(featureSize, shrinkParameter, pc, forestMap, leftId, minGrain); + forestRight = new Forest(featureSize, shrinkParameter, pc, forestMap, rightId, minGrain); + forestRight.setFather(this); + forestLeft.setFather(this); + Matrix conditionMatrixLeft = new Matrix(leftNub, featureSize);//条件矩阵左 + Matrix conditionMatrixRight = new Matrix(rightNub, featureSize);//条件矩阵右 + Matrix resultMatrixLeft = new Matrix(leftNub, 1);//结果矩阵左 + Matrix resultMatrixRight = new Matrix(rightNub, 1);//结果矩阵右 + forestLeft.setConditionMatrix(conditionMatrixLeft); + forestLeft.setResultMatrix(resultMatrixLeft); + forestRight.setConditionMatrix(conditionMatrixRight); + forestRight.setResultMatrix(resultMatrixRight); + int leftIndex = 0;//左矩阵添加行数 + int rightIndex = 0;//右矩阵添加行数 + for (int i = 0; i < y; i++) { + double nub; + if (isOldG) {//使用原有基 + nub = matrixAll.getNumber(i, oldGId); + } else {//使用新基 + Matrix parameter = matrixAll.getRow(i); + nub = transG(pc1, parameter, gNorm); + } + if (nub > median) {//进入右森林并计算右森林结果矩阵方差 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 + conditionMatrixRight.setNub(rightIndex, j, conditionMatrix.getNumber(i, j)); + } + resultMatrixRight.setNub(rightIndex, 0, resultMatrix.getNumber(i, 0)); + rightIndex++; + } else {//进入左森林并计算左森林结果矩阵方差 + for (int j = 0; j < featureSize; j++) {//进入右森林的条件矩阵 + conditionMatrixLeft.setNub(leftIndex, j, conditionMatrix.getNumber(i, j)); + } + resultMatrixLeft.setNub(leftIndex, 0, resultMatrix.getNumber(i, 0)); + leftIndex++; + } + } + //分区完成 + } + } + + public Matrix getConditionMatrix() { + return conditionMatrix; + } + + public void setConditionMatrix(Matrix conditionMatrix) { + this.conditionMatrix = conditionMatrix; + } + + public Matrix getResultMatrix() { + return resultMatrix; + } + + public void setResultMatrix(Matrix resultMatrix) { + this.resultMatrix = resultMatrix; + } + + public double[] getW() { + return w; + } + + public void setW(double[] w) { + this.w = w; + } + + public Forest getForestLeft() { + return forestLeft; + } + + public Forest getForestRight() { + return forestRight; + } + + public Forest getFather() { + return father; + } + + public void setFather(Forest father) { + this.father = father; + } + + public boolean isRemove() { + return isRemove; + } + + public void setRemove(boolean remove) { + isRemove = remove; + } + + public boolean isNotRemovable() { + return notRemovable; + } + + public void setNotRemovable(boolean notRemovable) { + this.notRemovable = notRemovable; + } +} diff --git a/src/main/java/org/wlld/regressionForest/RegressionForest.java b/src/main/java/org/wlld/regressionForest/RegressionForest.java new file mode 100644 index 0000000..95b8ba3 --- /dev/null +++ b/src/main/java/org/wlld/regressionForest/RegressionForest.java @@ -0,0 +1,274 @@ +package org.wlld.regressionForest; + +import org.wlld.MatrixTools.Matrix; +import org.wlld.MatrixTools.MatrixOperation; +import org.wlld.tools.Frequency; + +import java.util.*; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description 回归森林 + */ +public class RegressionForest extends Frequency { + private double[] w; + private Matrix conditionMatrix;//条件矩阵 + private Matrix resultMatrix;//结果矩阵 + private Forest forest; + private int featureNub;//特征数量 + private int xIndex = 0;//记录插入位置 + private double[] results;//结果数组 + private double min;//结果最小值 + private double max;//结果最大值 + private Matrix pc;//需要映射的基 + private int cosSize = 20;//cos 分成几份 + private TreeMap forestMap = new TreeMap<>();//节点列表 + + public int getCosSize() { + return cosSize; + } + + public void setCosSize(int cosSize) { + this.cosSize = cosSize; + } + + public RegressionForest(int size, int featureNub, double shrinkParameter, int minGrain) throws Exception {//初始化 + if (size > 0 && featureNub > 0) { + this.featureNub = featureNub; + w = new double[featureNub]; + results = new double[size]; + conditionMatrix = new Matrix(size, featureNub); + resultMatrix = new Matrix(size, 1); + createG(); + forest = new Forest(featureNub, shrinkParameter, pc, forestMap, 1, minGrain); + forestMap.put(1, forest); + forest.setW(w); + forest.setConditionMatrix(conditionMatrix); + forest.setResultMatrix(resultMatrix); + } else { + throw new Exception("size and featureNub too small"); + } + } + + public double getDist(double[] feature, double result) {//获取特征误差结果 + Forest forestFinish; + if (result <= min) {//直接找下边界区域 + forestFinish = getLimitRegion(forest, false); + } else if (result >= max) {//直接找到上边界区域 + forestFinish = getLimitRegion(forest, true); + } else { + forestFinish = getRegion(forest, result); + } + //计算误差 + double[] w = forestFinish.getW(); + double sigma = 0; + for (int i = 0; i < w.length; i++) { + double nub; + if (i < w.length - 1) { + nub = w[i] * feature[i]; + } else { + nub = w[i]; + } + sigma = sigma + nub; + } + return Math.abs(result - sigma); + } + + private Forest getRegion(Forest forest, double result) { + double median = forest.getMedian(); + if (result > median && forest.getForestRight() != null) {//向右走 + forest = forest.getForestRight(); + } else if (result <= median && forest.getForestLeft() != null) {//向左走 + forest = forest.getForestLeft(); + } else { + return forest; + } + return getRegion(forest, result); + } + + private Forest getLimitRegion(Forest forest, boolean isMax) { + Forest forestSon; + if (isMax) { + forestSon = forest.getForestRight(); + } else { + forestSon = forest.getForestLeft(); + } + if (forestSon != null) { + return getLimitRegion(forestSon, isMax); + } else { + return forest; + } + } + + private void createG() throws Exception {//生成新基 + double[] cg = new double[featureNub - 1]; + Random random = new Random(); + double sigma = 0; + for (int i = 0; i < featureNub - 1; i++) { + double rm = random.nextDouble(); + cg[i] = rm; + sigma = sigma + Math.pow(rm, 2); + } + double cosOne = 1.0D / cosSize; + double[] ag = new double[cosSize - 1];//装一个维度内所有角度的余弦值 + for (int i = 0; i < cosSize - 1; i++) { + double cos = cosOne * (i + 1); + ag[i] = Math.sqrt(sigma / (1 / Math.pow(cos, 2) - 1)); + } + int x = (cosSize - 1) * featureNub; + pc = new Matrix(x, featureNub); + for (int i = 0; i < featureNub; i++) {//遍历所有的固定基 + //以某个固定基摆动的所有新基集合的矩阵 + Matrix matrix = new Matrix(ag.length, featureNub); + for (int j = 0; j < ag.length; j++) { + for (int k = 0; k < featureNub; k++) { + if (k != i) { + if (k < i) { + matrix.setNub(j, k, cg[k]); + } else { + matrix.setNub(j, k, cg[k - 1]); + } + } else { + matrix.setNub(j, k, ag[j]); + } + } + } + //将一个固定基内摆动的新基都装到最大的集合内 + int index = (cosSize - 1) * i; + push(pc, matrix, index); + } + } + + //将两个矩阵从上到下进行合并 + private void push(Matrix mother, Matrix son, int index) throws Exception { + if (mother.getY() == son.getY()) { + int x = index + son.getX(); + int y = mother.getY(); + int start = 0; + for (int i = index; i < x; i++) { + for (int j = 0; j < y; j++) { + mother.setNub(i, j, son.getNumber(start, j)); + } + start++; + } + } else { + throw new Exception("matrix Y is not equals"); + } + } + + public void insertFeature(double[] feature, double result) throws Exception {//插入数据 + if (feature.length == featureNub - 1) { + for (int i = 0; i < featureNub; i++) { + if (i < featureNub - 1) { + conditionMatrix.setNub(xIndex, i, feature[i]); + } else { + results[xIndex] = result; + conditionMatrix.setNub(xIndex, i, 1.0); + resultMatrix.setNub(xIndex, 0, result); + } + } + xIndex++; + } else { + throw new Exception("feature length is not equals"); + } + } + + public void startStudy() throws Exception {//开始进行分段 + if (forest != null) { + //计算方差 + forest.setResultVariance(variance(results)); + double[] limit = getLimit(results); + min = limit[0]; + max = limit[1]; + start(forest); + //进行回归 + regression(); + //进行剪枝 + pruning(); + } else { + throw new Exception("rootForest is null"); + } + } + + private void start(Forest forest) throws Exception { + forest.cut(); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + start(forestLeft); + start(forestRight); + } + } + + private void pruning() throws Exception {//剪枝 + //先获取当前最大id + int max = forestMap.lastKey(); + int layersNub = (int) (Math.log(max) / Math.log(2));//当前的层数 + int lastMin = (int) Math.pow(2, layersNub);//最后一层最小的id + if (layersNub > 1) {//先遍历最后一层 + for (Map.Entry entry : forestMap.entrySet()) { + if (entry.getKey() >= lastMin) { + Forest forest = entry.getValue(); + forest.pruning(); + } + } + } + //每一层从下到上进行剪枝 + for (int i = layersNub - 1; i > 0; i--) { + int min = (int) Math.pow(2, i);//最后一层最小的id + int maxNub = (int) Math.pow(2, i + 1); + for (Map.Entry entry : forestMap.entrySet()) { + int key = entry.getKey(); + if (key >= min && key < maxNub) {//在范围内,进行剪枝 + entry.getValue().pruning(); + } else if (key >= maxNub) { + break; + } + } + } + //遍历所有节点,将删除的节点移除 + List list = new ArrayList<>(); + for (Map.Entry entry : forestMap.entrySet()) { + int key = entry.getKey(); + Forest forest = entry.getValue(); + if (forest.isRemove()) { + list.add(key); + } + } + for (int key : list) { + forestMap.remove(key); + } + } + + private void regression() throws Exception {//开始进行回归 + if (forest != null) { + regressionTree(forest); + } else { + throw new Exception("rootForest is null"); + } + } + + private void regressionTree(Forest forest) throws Exception { + regression(forest); + Forest forestLeft = forest.getForestLeft(); + Forest forestRight = forest.getForestRight(); + if (forestLeft != null && forestRight != null) { + regressionTree(forestLeft); + regressionTree(forestRight); + } + + } + + private void regression(Forest forest) throws Exception {//对分段进行线性回归 + Matrix conditionMatrix = forest.getConditionMatrix(); + Matrix resultMatrix = forest.getResultMatrix(); + Matrix ws = MatrixOperation.getLinearRegression(conditionMatrix, resultMatrix); + double[] w = forest.getW(); + for (int i = 0; i < ws.getX(); i++) { + w[i] = ws.getNumber(i, 0); + } + + } +} \ No newline at end of file diff --git a/src/main/java/org/wlld/tools/Frequency.java b/src/main/java/org/wlld/tools/Frequency.java index d7bb2cd..8bddcd5 100644 --- a/src/main/java/org/wlld/tools/Frequency.java +++ b/src/main/java/org/wlld/tools/Frequency.java @@ -99,4 +99,23 @@ public abstract class Frequency {//统计频数 } return ArithUtil.div(my, all); } + + public double[] getLimit(double[] m) {//获取数组中的最大值和最小值,最小值在前,最大值在后 + double[] limit = new double[2]; + double max = 0; + double min = -1; + int l = m.length; + for (int i = 0; i < l; i++) { + double nub = m[i]; + if (min == -1 || nub < min) { + min = nub; + } + if (nub > max) { + max = nub; + } + } + limit[0] = min; + limit[1] = max; + return limit; + } } diff --git a/src/test/java/coverTest/FoodTest.java b/src/test/java/coverTest/FoodTest.java index cf2ea69..b0ea1d6 100644 --- a/src/test/java/coverTest/FoodTest.java +++ b/src/test/java/coverTest/FoodTest.java @@ -61,7 +61,7 @@ public class FoodTest { Food food = templeConfig.getFood(); // cutting.setMaxRain(360);//切割阈值 - cutting.setTh(0.3); + cutting.setTh(0.6); cutting.setRegionNub(200); cutting.setMaxIou(2.0); //knn参数 @@ -73,8 +73,8 @@ public class FoodTest { //菜品识别实体类 food.setShrink(20);//缩紧像素 food.setTimes(2);//聚类数据增强 - food.setRowMark(0.1);//0.12 - food.setColumnMark(0.1);//0.25 + food.setRowMark(0.12);//0.12 + food.setColumnMark(0.12);//0.25 food.setRegressionNub(20000); food.setTrayTh(0.08); templeConfig.setClassifier(Classifier.KNN); @@ -99,18 +99,11 @@ public class FoodTest { ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg"); operation.setTray(threeChannelMatrix); for (int i = 1; i <= 1; i++) { - ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/a1.jpg"); - ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/b.jpg"); - ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/c.jpg"); + ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/test.jpg"); operation.colorStudy(threeChannelMatrix1, 1, specificationsList); - operation.colorStudy(threeChannelMatrix2, 2, specificationsList); - operation.colorStudy(threeChannelMatrix3, 3, specificationsList); } - -// minX==301,minY==430,maxX==854,maxY==920 -// minX==497,minY==1090,maxX==994,maxY==1520 - test2(templeConfig); + // test2(templeConfig); } public static void study() throws Exception { diff --git a/src/test/java/coverTest/ForestTest.java b/src/test/java/coverTest/ForestTest.java new file mode 100644 index 0000000..e3dacd5 --- /dev/null +++ b/src/test/java/coverTest/ForestTest.java @@ -0,0 +1,78 @@ +package coverTest; + +import org.wlld.randomForest.Tree; +import org.wlld.regressionForest.RegressionForest; + +import java.util.*; + +/** + * @param + * @DATA + * @Author LiDaPeng + * @Description + */ +public class ForestTest { + public static void main(String[] args) throws Exception { + test(); + //int a = (int) (Math.log(4) / Math.log(2));//id22是第几层 + //double a = Math.pow(2, 5) - 1; 第五层的第一个数 + // System.out.println("a==" + a); + + } + + public static void test() throws Exception {//对分段回归进行测试 + int size = 2000; + RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200); + regressionForest.setCosSize(40); + List a = fun(0.1, 0.2, 0.3, size, 2, 1); + List b = fun(0.7, 0.3, 0.1, size, 2, 2); + for (int i = 0; i < 1000; i++) { + double[] featureA = a.get(i); + double[] featureB = b.get(i); + double[] testA = new double[]{featureA[0], featureA[1]}; + double[] testB = new double[]{featureB[0], featureB[1]}; + regressionForest.insertFeature(testA, featureA[2]); + regressionForest.insertFeature(testB, featureB[2]); + } + regressionForest.startStudy(); + /// + List a1 = fun(0.1, 0.2, 0.3, size, 2, 1); + List b1 = fun(0.7, 0.3, 0.1, size, 2, 2); + double sigma = 0; + for (int i = 0; i < 1000; i++) { + double[] feature = a1.get(i); + double[] test = new double[]{feature[0], feature[1]}; + double dist = regressionForest.getDist(test, feature[2]); + sigma = sigma + dist; + } + double avs = sigma / size; + System.out.println("a误差:" + avs); +// a误差:0.0017585065712555645 +// b误差:0.00761733737464547 + sigma = 0; + for (int i = 0; i < 1000; i++) { + double[] feature = b1.get(i); + double[] test = new double[]{feature[0], feature[1]}; + double dist = regressionForest.getDist(test, feature[2]); + sigma = sigma + dist; + } + double avs2 = sigma / size; + System.out.println("b误差:" + avs2); + + } + + public static List fun(double w1, double w2, double w3, int size, int region, int index) {//生成假数据 + List list = new ArrayList<>(); + Random random = new Random(); + int nub = (index - 1) * 100; + double max = region * 100; + for (int i = 0; i < size; i++) { + double b = (double) (random.nextInt(100) + nub) / max; + double a = random.nextDouble(); + double c = w1 * a + w2 * b + w3; + double[] data = new double[]{a, b, c}; + list.add(data); + } + return list; + } +}