diff --git a/src/main/java/org/wlld/imageRecognition/Operation.java b/src/main/java/org/wlld/imageRecognition/Operation.java index 9250b41..f6354c8 100644 --- a/src/main/java/org/wlld/imageRecognition/Operation.java +++ b/src/main/java/org/wlld/imageRecognition/Operation.java @@ -12,6 +12,7 @@ import org.wlld.nerveEntity.SensoryNerve; import org.wlld.tools.ArithUtil; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -83,7 +84,7 @@ public class Operation {//进行计算 } //边框检测+识别分类 - public void lookWithPosition(Matrix matrix, long eventId) throws Exception { + public Map> lookWithPosition(Matrix matrix, long eventId) throws Exception { Frame frame = templeConfig.getFrame(); if (templeConfig.isHavePosition() && frame != null && frame.isReady()) { List frameBodies = convolution.getRegion(matrix, frame); @@ -104,21 +105,53 @@ public class Operation {//进行计算 //进入神经网络判断 intoNerve(eventId, list, templeConfig.getSensoryNerves(), false, null); } - toPositon(frameBodies, frame.getWidth(), frame.getHeight()); + return toPositon(frameBodies, frame.getWidth(), frame.getHeight()); } else if (templeConfig.getStudyPattern() == StudyPattern.Accuracy_Pattern) { + return null; + } else { + throw new Exception("wrong model"); } } else { throw new Exception("position not study or frame is not ready"); } } - private void toPositon(List frameBodies, int width, int height) throws Exception {//把分类都拿出来 + private Map> toPositon(List frameBodies, int width, int height) throws Exception {//把分类都拿出来 for (FrameBody frameBody : frameBodies) { if (frameBody.getPoint() > templeConfig.getTh()) {//存在一个识别分类 getBox(frameBody, width, height); } } + return result(frameBodies); + } + + private Map> result(List frameBodies) { + Map> map = new HashMap<>(); + for (FrameBody frameBody : frameBodies) { + if (frameBody.getPoint() > templeConfig.getTh()) {//存在一个识别分类 + int id = frameBody.getId(); + if (map.containsKey(id)) { + List frameBodies1 = map.get(id); + boolean isHere = false; + for (FrameBody frameBody1 : frameBodies1) { + double iou = getIou(frameBody1, frameBody); + if (iou > templeConfig.getIouTh()) { + isHere = true; + break; + } + } + if (!isHere) { + frameBodies1.add(frameBody); + } + } else { + List frameBodyList = new ArrayList<>(); + frameBodyList.add(frameBody); + map.put(id, frameBodyList); + } + } + } + return map; } //获得预测边框 diff --git a/src/main/java/org/wlld/imageRecognition/TempleConfig.java b/src/main/java/org/wlld/imageRecognition/TempleConfig.java index 25ec7cc..71b2496 100644 --- a/src/main/java/org/wlld/imageRecognition/TempleConfig.java +++ b/src/main/java/org/wlld/imageRecognition/TempleConfig.java @@ -34,6 +34,15 @@ public class TempleConfig { private ImageBack imageBack = new ImageBack();//边框图像回调 private double th = 0.6;//标准阈值 private boolean boxReady = false;//边框已经学习完毕 + private double iouTh = 0.5;//IOU阈值 + + public double getIouTh() { + return iouTh; + } + + public void setIouTh(double iouTh) { + this.iouTh = iouTh; + } public boolean isBoxReady() { return boxReady;