add ssd scipt

pull/891/head
zhaoting 5 years ago committed by yuzhenhua
parent 88215d0007
commit db79182005

@ -0,0 +1,88 @@
# SSD Example
## Description
SSD network based on MobileNetV2, with support for training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Dataset
We use coco2017 as training dataset in this example by default, and you can also use your own datasets.
1. If coco dataset is used. **Select dataset to coco when run script.**
Download coco2017: [train2017](http://images.cocodataset.org/zips/train2017.zip), [val2017](http://images.cocodataset.org/zips/val2017.zip), [test2017](http://images.cocodataset.org/zips/test2017.zip), [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). Install pycocotool.
```
pip install Cython
pip install pycocotools
```
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
```
└─coco2017
├── annotations # annotation jsons
├── train2017 # train dataset
└── val2017 # infer dataset
```
2. If your own dataset is used. **Select dataset to other when run script.**
Organize the dataset infomation into a TXT file, each row in the file is as follows:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`.
## Running the example
### Training
To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `COCO_ROOT`(coco dataset) or `IMAGE_DIR` and `ANNO_PATH`(own dataset). **Note if MINDRECORD_DIR isn't empty, it will use MINDRECORD_DIR instead of raw images.**
- Stand alone mode
```
python train.py --dataset coco
```
You can run ```python train.py -h``` to get more information.
- Distribute mode
```
sh run_distribute_train.sh 8 150 coco /data/hccl.json
```
The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**
You will get the loss value of each step as following:
```
epoch: 1 step: 455, loss is 5.8653416
epoch: 2 step: 455, loss is 5.4292373
epoch: 3 step: 455, loss is 5.458992
...
epoch: 148 step: 455, loss is 1.8340507
epoch: 149 step: 455, loss is 2.0876894
epoch: 150 step: 455, loss is 2.239692
```
### Evaluation
for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
```
python eval.py --ckpt_path ssd.ckpt --dataset coco
```
You can run ```python eval.py -h``` to get more information.

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
class ConfigSSD:
"""
Config parameters for SSD.
Examples:
ConfigSSD().
"""
IMG_SHAPE = [300, 300]
NUM_SSD_BOXES = 1917
NEG_PRE_POSITIVE = 3
MATCH_THRESHOLD = 0.5
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
SCALES = [21, 45, 99, 153, 207, 261, 315]
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
STEPS = (16, 32, 64, 100, 150, 300)
PRIOR_SCALING = (0.1, 0.2)
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
MINDRECORD_DIR = "MindRecord_COCO"
COCO_ROOT = "coco2017"
TRAIN_DATA_TYPE = "train2017"
VAL_DATA_TYPE = "val2017"
INSTANCES_SET = "annotations/instances_{}.json"
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')
NUM_CLASSES = len(COCO_CLASSES)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation for SSD"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from config import ConfigSSD
from util import metrics
def ssd_eval(dataset_path, ckpt_path):
"""SSD evaluation."""
ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False)
net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False)
print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
net.set_train(False)
i = 1.
total = ds.get_dataset_size()
start = time.time()
pred_data = []
print("\n========================================\n")
print("total images num: ", total)
print("Processing, please wait a moment.")
for data in ds.create_dict_iterator():
img_np = data['image']
image_shape = data['image_shape']
annotation = data['annotation']
output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx],
"annotation": annotation,
"image_shape": image_shape})
percent = round(i / total * 100, 2)
print(f' {str(percent)} [{i}/{total}]', end='\r')
i += 1
cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data)
print("\n========================================\n")
print(f"mAP: {mAP}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SSD evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
config = ConfigSSD()
prefix = "ssd_eval.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("COCO_ROOT not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
print("Start Eval!")
ssd_eval(mindrecord_file, args_opt.checkpoint_path)

@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
echo "It is better to use absolute path."
echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
echo "=============================================================================================================="
# Before start distribute train, first create mindrecord files.
python train.py --only_create_dataset=1
echo "After running the scipt, the network runs in the background. The log will be generated in LOGx/log.txt"
export RANK_SIZE=$1
EPOCH_SIZE=$2
DATASET=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
python ../train.py \
--distribute=1 \
--lr=0.4 \
--dataset=$DATASET \
--device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
cd ../
done

@ -0,0 +1,176 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train SSD and get checkpoint files."""
import os
import math
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer
from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from config import ConfigSSD
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + (lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def init_net_param(network, initialize_mode='XavierUniform'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
def main():
parser = argparse.ArgumentParser(description="SSD training")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
init()
rank = args_opt.device_id % device_num
else:
rank = 0
device_num = 1
print("Start create dataset!")
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is ssd.mindrecord0, 1, ... file_num.
config = ConfigSSD()
prefix = "ssd.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("COCO_ROOT not exits.")
else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else:
print("IMAGE_DIR or ANNO_PATH not exits.")
if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale)
# When create MindDataset, using the fitst mindrecord file, such as ssd.mindrecord0.
dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config)
net = SSDWithLossCell(ssd, config)
init_net_param(net)
# checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
warmup_epochs=max(args_opt.epoch_size // 20, 1),
total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
net = TrainingWrapper(net, opt, loss_scale)
if args_opt.checkpoint_path != "":
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
model = Model(net)
dataset_sink_mode = False
if args_opt.mode == "sink":
print("In sink mode, one epoch return a loss.")
dataset_sink_mode = True
print("Start train SSD, the first epoch will be slower because of the graph compilation.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
if __name__ == '__main__':
main()

@ -0,0 +1,208 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from config import ConfigSSD
from dataset import ssd_bboxes_decode
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
bbox_pred = np.expand_dims(bbox_pred, axis=0)
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
pred_area = pred_w * pred_h
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
gt_area = gt_w * gt_h
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)
intersection_area = iw * ih
union_area = pred_area + gt_area - intersection_area
union_area = np.maximum(union_area, np.finfo(float).eps)
iou = intersection_area * 1. / union_area
return iou
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def calc_ap(recall, precision):
"""Calculate AP."""
correct_recall = np.concatenate(([0.], recall, [1.]))
correct_precision = np.concatenate(([0.], precision, [0.]))
for i in range(correct_recall.size - 1, 0, -1):
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
return ap
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
config = ConfigSSD()
num_classes = config.NUM_CLASSES
all_detections = [None for i in range(num_classes)]
all_pred_scores = [None for i in range(num_classes)]
all_annotations = [None for i in range(num_classes)]
average_precisions = {}
num = [0 for i in range(num_classes)]
accurate_num = [0 for i in range(num_classes)]
for sample in pred_data:
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
image_shape = sample['image_shape']
annotation = np.squeeze(annotation, axis=0)
image_shape = np.squeeze(image_shape, axis=0)
pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape)
pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
boxes_scores = boxes_scores[index]
pred_labels = pred_labels[index]
top_k = 50
for c in range(1, num_classes):
if len(pred_labels) >= 1:
class_box_scores = boxes_scores[pred_labels == c]
class_boxes = pred_boxes[pred_labels == c]
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
cmask = class_box_scores > 0.5
class_boxes = class_boxes[cmask]
class_box_scores = class_box_scores[cmask]
all_detections[c] = class_boxes
all_pred_scores[c] = class_box_scores
for c in range(1, num_classes):
if len(annotation) >= 1:
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
for c in range(1, num_classes):
false_positives = np.zeros((0,))
true_positives = np.zeros((0,))
scores = np.zeros((0,))
num_annotations = 0.0
annotations = all_annotations[c]
num_annotations += annotations.shape[0]
detections = all_detections[c]
pred_scores = all_pred_scores[c]
for index, detection in enumerate(detections):
scores = np.append(scores, pred_scores[index])
if len(annotations) >= 1:
IoUs = calc_iou(detection, annotations)
assigned_anno = np.argmax(IoUs)
max_overlap = IoUs[assigned_anno]
if max_overlap >= 0.5:
false_positives = np.append(false_positives, 0)
true_positives = np.append(true_positives, 1)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
if num_annotations == 0:
if c not in average_precisions.keys():
average_precisions[c] = 0
continue
accurate_num[c] = 1
indices = np.argsort(-scores)
false_positives = false_positives[indices]
true_positives = true_positives[indices]
false_positives = np.cumsum(false_positives)
true_positives = np.cumsum(true_positives)
recall = true_positives * 1. / num_annotations
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
average_precision = calc_ap(recall, precision)
if c not in average_precisions.keys():
average_precisions[c] = average_precision
else:
average_precisions[c] += average_precision
num[c] += 1
count = 0
for key in average_precisions:
if num[key] != 0:
count += (average_precisions[key] / num[key])
mAP = count * 1. / accurate_num.count(1)
return mAP

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save