!57 增加二叉树回归

Merge pull request !57 from 逐光/test
pull/57/MERGE
逐光 4 years ago committed by Gitee
commit 6bbd0213c2

@ -243,7 +243,7 @@ public class MatrixOperation {
double nub = 0;
for (int i = 0; i < matrix.getX(); i++) {
for (int j = 0; j < matrix.getY(); j++) {
nub = ArithUtil.add(Math.pow(matrix.getNumber(i, j), 2), nub);
nub = Math.pow(matrix.getNumber(i, j), 2) + nub;
}
}
return Math.sqrt(nub);

@ -22,6 +22,22 @@ public class Knn {//KNN分类器
featureMap.remove(type);
}
public void revoke(int type, int nub) {//撤销一个类别最新的
List<Matrix> list = featureMap.get(type);
for (int i = 0; i < nub; i++) {
list.remove(list.size() - 1);
}
}
public int getNub(int type) {//获取该分类模型的数量
int nub = 0;
List<Matrix> list = featureMap.get(type);
if (list != null) {
nub = list.size();
}
return nub;
}
public void insertMatrix(Matrix vector, int tag) throws Exception {
if (vector.isVector() && vector.isRowVector()) {
if (featureMap.size() == 0) {

@ -277,13 +277,13 @@ public class Watershed {
regionBodies.add(regionBody);
}
}
// for (RegionBody regionBody : regionBodies) {
// int minX = regionBody.getMinX();
// int maxX = regionBody.getMaxX();
// int minY = regionBody.getMinY();
// int maxY = regionBody.getMaxY();
// System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY);
// }
for (RegionBody regionBody : regionBodies) {
int minX = regionBody.getMinX();
int maxX = regionBody.getMaxX();
int minY = regionBody.getMinY();
int maxY = regionBody.getMaxY();
System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY);
}
return iou(regionBodies);
}

File diff suppressed because it is too large Load Diff

@ -99,4 +99,23 @@ public abstract class Frequency {//统计频数
}
return ArithUtil.div(my, all);
}
public double[] getLimit(double[] m) {//获取数组中的最大值和最小值,最小值在前,最大值在后
double[] limit = new double[2];
double max = 0;
double min = -1;
int l = m.length;
for (int i = 0; i < l; i++) {
double nub = m[i];
if (min == -1 || nub < min) {
min = nub;
}
if (nub > max) {
max = nub;
}
}
limit[0] = min;
limit[1] = max;
return limit;
}
}

@ -61,7 +61,7 @@ public class FoodTest {
Food food = templeConfig.getFood();
//
cutting.setMaxRain(360);//切割阈值
cutting.setTh(0.3);
cutting.setTh(0.6);
cutting.setRegionNub(200);
cutting.setMaxIou(2.0);
//knn参数
@ -73,8 +73,8 @@ public class FoodTest {
//菜品识别实体类
food.setShrink(20);//缩紧像素
food.setTimes(2);//聚类数据增强
food.setRowMark(0.1);//0.12
food.setColumnMark(0.1);//0.25
food.setRowMark(0.12);//0.12
food.setColumnMark(0.12);//0.25
food.setRegressionNub(20000);
food.setTrayTh(0.08);
templeConfig.setClassifier(Classifier.KNN);
@ -99,18 +99,11 @@ public class FoodTest {
ThreeChannelMatrix threeChannelMatrix = picture.getThreeMatrix("/Users/lidapeng/Desktop/myDocument/d.jpg");
operation.setTray(threeChannelMatrix);
for (int i = 1; i <= 1; i++) {
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/a1.jpg");
ThreeChannelMatrix threeChannelMatrix2 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/b.jpg");
ThreeChannelMatrix threeChannelMatrix3 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/c.jpg");
ThreeChannelMatrix threeChannelMatrix1 = picture.getThreeMatrix("/Users/lidapeng/Desktop/test/test.jpg");
operation.colorStudy(threeChannelMatrix1, 1, specificationsList);
operation.colorStudy(threeChannelMatrix2, 2, specificationsList);
operation.colorStudy(threeChannelMatrix3, 3, specificationsList);
}
// minX==301,minY==430,maxX==854,maxY==920
// minX==497,minY==1090,maxX==994,maxY==1520
test2(templeConfig);
// test2(templeConfig);
}
public static void study() throws Exception {

@ -0,0 +1,78 @@
package coverTest;
import org.wlld.randomForest.Tree;
import org.wlld.regressionForest.RegressionForest;
import java.util.*;
/**
* @param
* @DATA
* @Author LiDaPeng
* @Description
*/
public class ForestTest {
public static void main(String[] args) throws Exception {
test();
//int a = (int) (Math.log(4) / Math.log(2));//id22是第几层
//double a = Math.pow(2, 5) - 1; 第五层的第一个数
// System.out.println("a==" + a);
}
public static void test() throws Exception {//对分段回归进行测试
int size = 2000;
RegressionForest regressionForest = new RegressionForest(size, 3, 0.01, 200);
regressionForest.setCosSize(40);
List<double[]> a = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b = fun(0.7, 0.3, 0.1, size, 2, 2);
for (int i = 0; i < 1000; i++) {
double[] featureA = a.get(i);
double[] featureB = b.get(i);
double[] testA = new double[]{featureA[0], featureA[1]};
double[] testB = new double[]{featureB[0], featureB[1]};
regressionForest.insertFeature(testA, featureA[2]);
regressionForest.insertFeature(testB, featureB[2]);
}
regressionForest.startStudy();
///
List<double[]> a1 = fun(0.1, 0.2, 0.3, size, 2, 1);
List<double[]> b1 = fun(0.7, 0.3, 0.1, size, 2, 2);
double sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = a1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}
double avs = sigma / size;
System.out.println("a误差" + avs);
// a误差0.0017585065712555645
// b误差0.00761733737464547
sigma = 0;
for (int i = 0; i < 1000; i++) {
double[] feature = b1.get(i);
double[] test = new double[]{feature[0], feature[1]};
double dist = regressionForest.getDist(test, feature[2]);
sigma = sigma + dist;
}
double avs2 = sigma / size;
System.out.println("b误差" + avs2);
}
public static List<double[]> fun(double w1, double w2, double w3, int size, int region, int index) {//生成假数据
List<double[]> list = new ArrayList<>();
Random random = new Random();
int nub = (index - 1) * 100;
double max = region * 100;
for (int i = 0; i < size; i++) {
double b = (double) (random.nextInt(100) + nub) / max;
double a = random.nextDouble();
double c = w1 * a + w2 * b + w3;
double[] data = new double[]{a, b, c};
list.add(data);
}
return list;
}
}
Loading…
Cancel
Save