diff --git a/model_zoo/official/cv/ssd/postprocess.py b/model_zoo/official/cv/ssd/postprocess.py index 3e0212ede9..05c570817f 100644 --- a/model_zoo/official/cv/ssd/postprocess.py +++ b/model_zoo/official/cv/ssd/postprocess.py @@ -25,6 +25,7 @@ batch_size = 1 parser = argparse.ArgumentParser(description="ssd_mobilenet_v1_fpn inference") parser.add_argument("--result_path", type=str, required=True, help="result files path.") parser.add_argument("--img_path", type=str, required=True, help="image file path.") +parser.add_argument("--drop", action="store_true", help="drop iscrowd images or not.") args = parser.parse_args() def get_imgSize(file_name): @@ -33,15 +34,43 @@ def get_imgSize(file_name): def get_result(result_path, img_id_file_path): anno_json = os.path.join(config.coco_root, config.instances_set.format(config.val_data_type)) + + if args.drop: + from pycocotools.coco import COCO + train_cls = config.classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + files = os.listdir(img_id_file_path) pred_data = [] for file in files: img_ids_name = file.split('.')[0] img_id = int(np.squeeze(img_ids_name)) + if args.drop: + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + annos = [] + iscrowd = False + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] + if class_name in train_cls: + x_min, x_max = bbox[0], bbox[0] + bbox[2] + y_min, y_max = bbox[1], bbox[1] + bbox[3] + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + if iscrowd or (not annos): + continue + img_size = get_imgSize(os.path.join(img_id_file_path, file)) image_shape = np.array([img_size[1], img_size[0]]) - result_path_0 = os.path.join(result_path, img_ids_name + "_0.bin") result_path_1 = os.path.join(result_path, img_ids_name + "_1.bin") diff --git a/model_zoo/official/cv/ssd/scripts/run_infer_310.sh b/model_zoo/official/cv/ssd/scripts/run_infer_310.sh index 7e77f276a5..59be2e3cd9 100644 --- a/model_zoo/official/cv/ssd/scripts/run_infer_310.sh +++ b/model_zoo/official/cv/ssd/scripts/run_infer_310.sh @@ -78,7 +78,7 @@ function infer() function cal_acc() { - python3.7 ../postprocess.py --result_path=./result_Files --img_path=$data_path &> acc.log & + python3.7 ../postprocess.py --result_path=./result_Files --img_path=$data_path --drop &> acc.log & } compile_app