!57 增加二叉树回归

Merge pull request !57 from 逐光/test
pull/57/MERGE
逐光 4 years ago committed by Gitee
commit 6bbd0213c2

@ -243,7 +243,7 @@ public class MatrixOperation {
double nub = 0; double nub = 0;
for (int i = 0; i < matrix.getX(); i++) { for (int i = 0; i < matrix.getX(); i++) {
for (int j = 0; j < matrix.getY(); j++) { for (int j = 0; j < matrix.getY(); j++) {
nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub); nub = Math.pow(matrix.getNumber(i, j), 2) + nub;
} }
} }
return Math.sqrt(nub); return Math.sqrt(nub);

@ -22,6 +22,22 @@ public class Knn {//KNN分类器
featureMap.remove(type); featureMap.remove(type);
} }
public void revoke(int type, int nub) {//撤销一个类别最新的
List<Matrix> list = featureMap.get(type);
for (int i = 0; i < nub; i++) {
list.remove(list.size() - 1);
}
}
public int getNub(int type) {//获取该分类模型的数量
int nub = 0;
List<Matrix> list = featureMap.get(type);
if (list != null) {
nub = list.size();
}
return nub;
}
public void insertMatrix(Matrix vector, int tag) throws Exception { public void insertMatrix(Matrix vector, int tag) throws Exception {
if (vector.isVector() && vector.isRowVector()) { if (vector.isVector() && vector.isRowVector()) {
if (featureMap.size() == 0) { if (featureMap.size() == 0) {

@ -277,13 +277,13 @@ public class Watershed {
regionBodies.add(regionBody); regionBodies.add(regionBody);
} }
} }
// for (RegionBody regionBody : regionBodies) { for (RegionBody regionBody : regionBodies) {
// int minX = regionBody.getMinX(); int minX = regionBody.getMinX();
// int maxX = regionBody.getMaxX(); int maxX = regionBody.getMaxX();
// int minY = regionBody.getMinY(); int minY = regionBody.getMinY();
// int maxY = regionBody.getMaxY(); int maxY = regionBody.getMaxY();
// System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY);
// } }
return iou(regionBodies); return iou(regionBodies);
} }

File diff suppressed because it is too large Load Diff

@ -99,4 +99,23 @@ public abstract class Frequency {//统计频数
} }
return ArithUtil.div(my, all); return ArithUtil.div(my, all);
} }
public double[] getLimit(double[] m) {//获取数组中的最大值和最小值,最小值在前,最大值在后
double[] limit = new double[2];
double max = 0;
double min = -1;
int l = m.length;
for (int i = 0; i < l; i++) {
double nub = m[i];
if (min == -1 || nub < min) {
min = nub;
}
if (nub > max) {
max = nub;
}
}
limit[0] = min;
limit[1] = max;
return limit;
}
} }

@ -61,7 +61,7 @@ public class FoodTest {
Food food = templeConfig.getFood(); Food food = templeConfig.getFood();
// //
cutting.setMaxRain(360);//切割阈值 cutting.setMaxRain(360);//切割阈值
cutting.setTh(0.3); cutting.setTh(0.6);
cutting.setRegionNub(200); cutting.setRegionNub(200);
cutting.setMaxIou(2.0); cutting.setMaxIou(2.0);
//knn参数 //knn参数
@ -73,8 +73,8 @@ public class FoodTest {
//菜品识别实体类 //菜品识别实体类
food.setShrink(20);//缩紧像素 food.setShrink(20);//缩紧像素
food.setTimes(2);//聚类数据增强 food.setTimes(2);//聚类数据增强
food.setRowMark(0.1);//0.12 food.setRowMark(0.12);//0.12
food.setColumnMark(0.1);//0.25 food.setColumnMark(0.12);//0.25
food.setRegressionNub(20000); food.setRegressionNub(20000);
food.setTrayTh(0.08); food.setTrayTh(0.08);
templeConfig.setClassifier(Classifier.KNN); templeConfig.setClassifier(Classifier.KNN);
@ -99,18 +99,11 @@ public class FoodTest {
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg"); ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg");
operation.setTray(threeChannelMatrix); operation.setTray(threeChannelMatrix);
for (int i = 1; i <= 1; i++) { for (int i = 1; i <= 1; i++) {
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/a1.jpg"); ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/test.jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/b.jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/c.jpg");
operation.colorStudy(threeChannelMatrix1, 1, specificationsList); operation.colorStudy(threeChannelMatrix1, 1, specificationsList);
operation.colorStudy(threeChannelMatrix2, 2, specificationsList);
operation.colorStudy(threeChannelMatrix3, 3, specificationsList);
} }
// test2(templeConfig);
// minX==301,minY==430,maxX==854,maxY==920
// minX==497,minY==1090,maxX==994,maxY==1520
test2(templeConfig);
} }
public static void study() throws Exception { public static void study() throws Exception {

@ -0,0 +1,78 @@
package coverTest;
import org.wlld.randomForest.Tree;
import org.wlld.regressionForest.RegressionForest;
import java.util.*;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description
*/
public class ForestTest {
public static void main(String[] args) throws Exception {
test();
//int a = (int) (Math.log(4) / Math.log(2));//id22是第几层
//double a = Math.pow(2, 5) - 1; 第五层的第一个数
// System.out.println("a==" + a);
}
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);
for (int i = 0; i < 1000; i++) {
double[] featureA = a.get(i);
double[] featureB = b.get(i);
double[] testA = new double[]{featureA[0], featureA[1]};
double[] testB = new double[]{featureB[0], featureB[1]};
regressionForest.insertFeature(testA, featureA[2]);
regressionForest.insertFeature(testB, featureB[2]);
}
regressionForest.startStudy();
///
List<double[]> a1 = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b1 = fun(0.7, 0.3, 0.1, size, 2, 2);
double sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = a1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}
double avs = sigma / size;
System.out.println("a误差" + avs);
// a误差0.0017585065712555645
// b误差0.00761733737464547
sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = b1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}
double avs2 = sigma / size;
System.out.println("b误差" + avs2);
}
public static List<double[]> fun(double w1, double w2, double w3, int size, int region, int index) {//生成假数据
List<double[]> list = new ArrayList<>();
Random random = new Random();
int nub = (index - 1) * 100;
double max = region * 100;
for (int i = 0; i < size; i++) {
double b = (double) (random.nextInt(100) + nub) / max;
double a = random.nextDouble();
double c = w1 * a + w2 * b + w3;
double[] data = new double[]{a, b, c};
list.add(data);
}
return list;
}
}
Loading…
Cancel
Save