!51 增加多元线性回归可调参数

Merge pull request !51 from 逐光/test
pull/51/MERGE
逐光 5 years ago committed by Gitee
commit 1d90bb3ea9

@ -64,13 +64,13 @@ public class Convolution extends Frequency {
} }
public List<List<Double>> kAvg(ThreeChannelMatrix threeMatrix, int sqNub public List<List<Double>> kAvg(ThreeChannelMatrix threeMatrix, int sqNub
, int regionSize) throws Exception { , int regionSize, TempleConfig templeConfig) throws Exception {
RGBSort rgbSort = new RGBSort(); RGBSort rgbSort = new RGBSort();
List<List<Double>> features = new ArrayList<>(); List<List<Double>> features = new ArrayList<>();
List<ThreeChannelMatrix> threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize); List<ThreeChannelMatrix> threeChannelMatrixList = regionThreeChannelMatrix(threeMatrix, regionSize);
for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) { for (ThreeChannelMatrix threeChannelMatrix : threeChannelMatrixList) {
List<Double> feature = new ArrayList<>(); List<Double> feature = new ArrayList<>();
MeanClustering meanClustering = new MeanClustering(sqNub); MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB(); Matrix matrixB = threeChannelMatrix.getMatrixB();
@ -165,7 +165,7 @@ public class Convolution extends Frequency {
//System.out.println(matrixBD.getString()); //System.out.println(matrixBD.getString());
} }
public List<Double> getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub) throws Exception { public List<Double> getCenterColor(ThreeChannelMatrix threeChannelMatrix, int poolSize, int sqNub, TempleConfig templeConfig) throws Exception {
Matrix matrixR = threeChannelMatrix.getMatrixR(); Matrix matrixR = threeChannelMatrix.getMatrixR();
Matrix matrixG = threeChannelMatrix.getMatrixG(); Matrix matrixG = threeChannelMatrix.getMatrixG();
Matrix matrixB = threeChannelMatrix.getMatrixB(); Matrix matrixB = threeChannelMatrix.getMatrixB();
@ -175,7 +175,7 @@ public class Convolution extends Frequency {
RGBSort rgbSort = new RGBSort(); RGBSort rgbSort = new RGBSort();
int x = matrixR.getX(); int x = matrixR.getX();
int y = matrixR.getY(); int y = matrixR.getY();
MeanClustering meanClustering = new MeanClustering(sqNub); MeanClustering meanClustering = new MeanClustering(sqNub, templeConfig);
for (int i = 0; i < x; i++) { for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) { for (int j = 0; j < y; j++) {
double[] color = new double[]{matrixR.getNumber(i, j), matrixG.getNumber(i, j), matrixB.getNumber(i, j)}; double[] color = new double[]{matrixR.getNumber(i, j), matrixG.getNumber(i, j), matrixB.getNumber(i, j)};

@ -16,8 +16,9 @@ public class MeanClustering {
return matrices; return matrices;
} }
public MeanClustering(int speciesQuantity) { public MeanClustering(int speciesQuantity, TempleConfig templeConfig) {
this.speciesQuantity = speciesQuantity;//聚类的数量 this.speciesQuantity = speciesQuantity;//聚类的数量
size = templeConfig.getFood().getRegressionNub();
} }
public void setColor(double[] color) throws Exception { public void setColor(double[] color) throws Exception {

@ -111,7 +111,7 @@ public class Operation {//进行计算
int times = templeConfig.getFood().getTimes(); int times = templeConfig.getFood().getTimes();
for (int i = 0; i < times; i++) { for (int i = 0; i < times; i++) {
List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
templeConfig.getFeatureNub()); templeConfig.getFeatureNub(), templeConfig);
if (templeConfig.isShowLog()) { if (templeConfig.isShowLog()) {
System.out.println(tag + ":" + feature); System.out.println(tag + ":" + feature);
} }
@ -178,7 +178,7 @@ public class Operation {//进行计算
ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize); ThreeChannelMatrix threeChannelMatrix1 = convolution.getRegionMatrix(threeChannelMatrix, minX, minY, xSize, ySize);
//convolution.filtering(threeChannelMatrix1);//光照过滤 //convolution.filtering(threeChannelMatrix1);//光照过滤
List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(), List<Double> feature = convolution.getCenterColor(threeChannelMatrix1, templeConfig.getPoolSize(),
templeConfig.getFeatureNub()); templeConfig.getFeatureNub(), templeConfig);
if (templeConfig.isShowLog()) { if (templeConfig.isShowLog()) {
System.out.println(feature); System.out.println(feature);
} }
@ -272,7 +272,7 @@ public class Operation {//进行计算
CoverBody coverBody = new CoverBody(); CoverBody coverBody = new CoverBody();
Map<Integer, Double> tag = new HashMap<>(); Map<Integer, Double> tag = new HashMap<>();
tag.put(entry.getKey(), 1.0); tag.put(entry.getKey(), 1.0);
List<List<Double>> lists = convolution.kAvg(entry.getValue(), sqNub, regionSize); List<List<Double>> lists = convolution.kAvg(entry.getValue(), sqNub, regionSize, templeConfig);
size = lists.size(); size = lists.size();
coverBody.setFeature(lists); coverBody.setFeature(lists);
coverBody.setTag(tag); coverBody.setTag(tag);
@ -296,7 +296,7 @@ public class Operation {//进行计算
if (templeConfig.getStudyPattern() == StudyPattern.Cover_Pattern) { if (templeConfig.getStudyPattern() == StudyPattern.Cover_Pattern) {
Map<Integer, Double> coverMap = new HashMap<>(); Map<Integer, Double> coverMap = new HashMap<>();
Map<Integer, Integer> typeNub = new HashMap<>(); Map<Integer, Integer> typeNub = new HashMap<>();
List<List<Double>> lists = convolution.kAvg(matrix, sqNub, regionSize); List<List<Double>> lists = convolution.kAvg(matrix, sqNub, regionSize, templeConfig);
//特征塞入容器完毕 //特征塞入容器完毕
int size = lists.size(); int size = lists.size();
int all = 0; int all = 0;

@ -38,6 +38,7 @@ public class Watershed {
private double columnMark;//列过滤 private double columnMark;//列过滤
private List<Specifications> specifications;//过滤候选区参数 private List<Specifications> specifications;//过滤候选区参数
private List<RgbRegression> trayBody;//托盘参数 private List<RgbRegression> trayBody;//托盘参数
private double trayTh;
public Watershed(ThreeChannelMatrix matrix, List<Specifications> specifications, TempleConfig templeConfig) throws Exception { public Watershed(ThreeChannelMatrix matrix, List<Specifications> specifications, TempleConfig templeConfig) throws Exception {
if (matrix != null && specifications != null && specifications.size() > 0) { if (matrix != null && specifications != null && specifications.size() > 0) {
@ -61,6 +62,7 @@ public class Watershed {
xSize = this.matrix.getX() / regionNub; xSize = this.matrix.getX() / regionNub;
ySize = this.matrix.getY() / regionNub; ySize = this.matrix.getY() / regionNub;
maxIou = templeConfig.getCutting().getMaxIou(); maxIou = templeConfig.getCutting().getMaxIou();
trayTh = templeConfig.getFood().getTrayTh();
// System.out.println("xSize===" + xSize + ",ysize===" + ySize); // System.out.println("xSize===" + xSize + ",ysize===" + ySize);
rainfallMap = new Matrix(this.matrix.getX(), this.matrix.getY()); rainfallMap = new Matrix(this.matrix.getX(), this.matrix.getY());
regionMap = new Matrix(regionNub, regionNub); regionMap = new Matrix(regionNub, regionNub);
@ -77,7 +79,7 @@ public class Watershed {
matrixB.getNumber(x, y) / 255}; matrixB.getNumber(x, y) / 255};
for (RgbRegression rgbRegression : trayBody) { for (RgbRegression rgbRegression : trayBody) {
double dist = rgbRegression.getDisError(rgb); double dist = rgbRegression.getDisError(rgb);
if (dist < 0) { if (dist < trayTh) {
isTray = true; isTray = true;
break; break;
} }
@ -274,15 +276,15 @@ public class Watershed {
regionBodies.add(regionBody); regionBodies.add(regionBody);
} }
} }
for (RegionBody regionBody : regionBodies) { // for (RegionBody regionBody : regionBodies) {
int minX = regionBody.getMinX(); // int minX = regionBody.getMinX();
int maxX = regionBody.getMaxX(); // int maxX = regionBody.getMaxX();
int minY = regionBody.getMinY(); // int minY = regionBody.getMinY();
int maxY = regionBody.getMaxY(); // int maxY = regionBody.getMaxY();
System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY); // System.out.println("minX==" + minX + ",minY==" + minY + ",maxX==" + maxX + ",maxY==" + maxY);
} // }
//return iou(regionBodies); return iou(regionBodies);
return regionBodies; //return regionBodies;
} }
private List<RegionBody> iou(List<RegionBody> regionBodies) { private List<RegionBody> iou(List<RegionBody> regionBodies) {

@ -18,6 +18,15 @@ public class Food {
private double columnMark = 0.25;//列痕迹过滤 private double columnMark = 0.25;//列痕迹过滤
private List<RgbRegression> trayBody = new ArrayList<>();//托盘实体参数 private List<RgbRegression> trayBody = new ArrayList<>();//托盘实体参数
private int regressionNub = 10000;//回归次数 private int regressionNub = 10000;//回归次数
private double trayTh = 0.1;//托盘回归阈值
public double getTrayTh() {
return trayTh;
}
public void setTrayTh(double trayTh) {
this.trayTh = trayTh;
}
public int getRegressionNub() { public int getRegressionNub() {
return regressionNub; return regressionNub;

Loading…
Cancel
Save