parent 88215d0007
commit db79182005
@@ -0,0 +1,88 @@
# SSD Example

## Description

SSD network based on MobileNetV2, with support for training and evaluation.

## Requirements

- Install [MindSpore](https://www.mindspore.cn/install/en).

- Dataset

  We use COCO2017 as the training dataset in this example by default, but you can also use your own datasets.

  1. If the COCO dataset is used, **select dataset `coco` when running the script.**

     Download COCO2017: [train2017](http://images.cocodataset.org/zips/train2017.zip), [val2017](http://images.cocodataset.org/zips/val2017.zip), [test2017](http://images.cocodataset.org/zips/test2017.zip), [annotations](http://images.cocodataset.org/annotations/annotations_trainval2017.zip). Install pycocotools:

     ```
     pip install Cython
     pip install pycocotools
     ```

     Then change `COCO_ROOT` and any other settings you need in `config.py`. The directory structure is as follows:

     ```
     └─coco2017
       ├── annotations  # annotation jsons
       ├── train2017    # train dataset
       └── val2017      # infer dataset
     ```

  2. If your own dataset is used, **select dataset `other` when running the script.**

     Organize the dataset information into a TXT file, in which each row is as follows:

     ```
     train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
     ```

     Each row is a space-separated annotation for one image: the first column is the relative path of the image, and each remaining column is a box and its class in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) with the relative paths listed in `ANNO_PATH` (the path of the TXT file); both `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`.
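
     As an illustration, a row in this format could be parsed as in the minimal sketch below (the names here are placeholders; the actual parsing is implemented in `dataset.py`, which is not shown in this diff):

     ```
     import os

     IMAGE_DIR = "dataset_root"  # placeholder; the real value is set in config.py
     line = "train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2"

     fields = line.strip().split(" ")
     image_path = os.path.join(IMAGE_DIR, fields[0])
     # Each remaining field is one box: xmin,ymin,xmax,ymax,class.
     boxes = [[int(v) for v in field.split(",")] for field in fields[1:]]
     print(image_path, boxes)
     ```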

## Running the example

### Training

To train the model, run `train.py`. If `MINDRECORD_DIR` is empty, the script first generates [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files from `COCO_ROOT` (COCO dataset) or from `IMAGE_DIR` and `ANNO_PATH` (own dataset). **Note that if `MINDRECORD_DIR` isn't empty, the files in `MINDRECORD_DIR` are used instead of the raw images.**

- Standalone mode

  ```
  python train.py --dataset coco
  ```

  You can run `python train.py -h` to get more information.

- Distributed mode

  ```
  sh run_distribute_train.sh 8 150 coco /data/hccl.json
  ```

  The input parameters are the number of devices, the epoch size, the dataset mode and the [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**

  You will get the loss value of each step, as follows:

  ```
  epoch: 1 step: 455, loss is 5.8653416
  epoch: 2 step: 455, loss is 5.4292373
  epoch: 3 step: 455, loss is 5.458992
  ...
  epoch: 148 step: 455, loss is 1.8340507
  epoch: 149 step: 455, loss is 2.0876894
  epoch: 150 step: 455, loss is 2.239692
  ```

### Evaluation

For evaluation, run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.

```
python eval.py --checkpoint_path ssd.ckpt --dataset coco
```

You can run `python eval.py -h` to get more information.

@@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Config parameters for SSD models."""


class ConfigSSD:
    """
    Config parameters for SSD.

    Examples:
        ConfigSSD().
    """
    IMG_SHAPE = [300, 300]
    NUM_SSD_BOXES = 1917
    NEG_PRE_POSITIVE = 3
    MATCH_THRESHOLD = 0.5

    NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
    EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
    EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
    EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
    EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
    FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
    SCALES = [21, 45, 99, 153, 207, 261, 315]
    ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
    STEPS = (16, 32, 64, 100, 150, 300)
    PRIOR_SCALING = (0.1, 0.2)

    # It is better to use absolute paths for `MINDRECORD_DIR` and `COCO_ROOT`.
    MINDRECORD_DIR = "MindRecord_COCO"
    COCO_ROOT = "coco2017"
    TRAIN_DATA_TYPE = "train2017"
    VAL_DATA_TYPE = "val2017"
    INSTANCES_SET = "annotations/instances_{}.json"
    # `IMAGE_DIR` and `ANNO_PATH` are referenced by train.py and eval.py when
    # `dataset` is "other" (see README); the empty values are placeholders.
    IMAGE_DIR = ""
    ANNO_PATH = ""
    COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                    'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                    'kite', 'baseball bat', 'baseball glove', 'skateboard',
                    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                    'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                    'refrigerator', 'book', 'clock', 'vase', 'scissors',
                    'teddy bear', 'hair drier', 'toothbrush')
    NUM_CLASSES = len(COCO_CLASSES)
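
    # Illustrative check: NUM_SSD_BOXES equals the total number of default
    # boxes over all feature maps:
    #   sum(n * s * s for n, s in zip(NUM_DEFAULT, FEATURE_SIZE))
    #   = 3*19*19 + 6*10*10 + 6*5*5 + 6*3*3 + 6*2*2 + 6*1*1 = 1917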

File diff suppressed because it is too large

@@ -0,0 +1,99 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for SSD"""
import os
import argparse
import time
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from config import ConfigSSD
from util import metrics


def ssd_eval(dataset_path, ckpt_path):
    """SSD evaluation."""

    ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False)
    net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False)
    print("Load Checkpoint!")
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)

    net.set_train(False)
    i = 1.
    total = ds.get_dataset_size()
    start = time.time()
    pred_data = []
    print("\n========================================\n")
    print("total images num: ", total)
    print("Processing, please wait a moment.")
    for data in ds.create_dict_iterator():
        img_np = data['image']
        image_shape = data['image_shape']
        annotation = data['annotation']

        output = net(Tensor(img_np))
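        # output[0]: predicted box coordinates; output[1]: per-class box
        # scores. Both are consumed per image by metrics() in util.py.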
        for batch_idx in range(img_np.shape[0]):
            pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
                              "box_scores": output[1].asnumpy()[batch_idx],
                              "annotation": annotation,
                              "image_shape": image_shape})
        percent = round(i / total * 100, 2)

        print(f' {str(percent)} [{i}/{total}]', end='\r')
        i += 1
    cost_time = int((time.time() - start) * 1000)
    print(f' 100% [{total}/{total}] cost {cost_time} ms')
    mAP = metrics(pred_data)
    print("\n========================================\n")
    print(f"mAP: {mAP}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SSD evaluation')
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--checkpoint_path", type=str, required=True, help="Checkpoint file path.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)

    config = ConfigSSD()
    prefix = "ssd_eval.mindrecord"
    mindrecord_dir = config.MINDRECORD_DIR
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if args_opt.dataset == "coco":
            if os.path.isdir(config.COCO_ROOT):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", False, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("COCO_ROOT does not exist.")
        else:
            if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", False, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("IMAGE_DIR or ANNO_PATH does not exist.")

    print("Start Eval!")
    ssd_eval(mindrecord_file, args_opt.checkpoint_path)

@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATASET MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
echo "It is better to use absolute paths."
echo "The learning rate is 0.4 by default; if you want another lr, please change the value in this script."
echo "=============================================================================================================="

# Before starting distributed training, create the mindrecord files first.
python train.py --only_create_dataset=1

echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"

export RANK_SIZE=$1
EPOCH_SIZE=$2
DATASET=$3
export MINDSPORE_HCCL_CONFIG_PATH=$4

for((i=0;i<RANK_SIZE;i++))
do
    export DEVICE_ID=$i
    rm -rf LOG$i
    mkdir ./LOG$i
    cp *.py ./LOG$i
    cd ./LOG$i || exit
    export RANK_ID=$i
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    python ../train.py \
        --distribute=1 \
        --lr=0.4 \
        --dataset=$DATASET \
        --device_num=$RANK_SIZE \
        --device_id=$DEVICE_ID \
        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
    cd ../
done

@@ -0,0 +1,176 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""train SSD and get checkpoint files."""

import os
import math
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer

from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from config import ConfigSSD
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image


def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    Generate learning rate array.

    Args:
        global_step(int): current global step, used as the start index of the array
        lr_init(float): initial learning rate
        lr_end(float): end learning rate
        lr_max(float): max learning rate
        warmup_epochs(int): number of warmup epochs
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_end + (lr_max - lr_end) * \
                (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
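
# Illustrative example: get_lr(0, lr_init=0., lr_end=0., lr_max=0.4,
# warmup_epochs=1, total_epochs=2, steps_per_epoch=2) returns
# [0.0, 0.2, 0.4, 0.2]: linear warmup over the first two steps, then
# cosine decay from lr_max toward lr_end over the remaining steps.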


def init_net_param(network, initialize_mode='XavierUniform'):
    """Init the parameters in net."""
    params = network.trainable_params()
    for p in params:
        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
            p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))


def main():
    parser = argparse.ArgumentParser(description="SSD training")
    parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
                                                                                "Mindrecord, default is false.")
    parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
    parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.")
    parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
    parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.")
    parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
    parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_task_sink=True, enable_loop_sink=True, enable_mem_reuse=True)

    if args_opt.distribute:
        device_num = args_opt.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                          device_num=device_num)
        init()
        rank = args_opt.device_id % device_num
    else:
        rank = 0
        device_num = 1

    print("Start create dataset!")

    # It will generate mindrecord files in config.MINDRECORD_DIR,
    # named ssd.mindrecord0, 1, ... file_num.
    config = ConfigSSD()
    prefix = "ssd.mindrecord"
    mindrecord_dir = config.MINDRECORD_DIR
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    if not os.path.exists(mindrecord_file):
        if not os.path.isdir(mindrecord_dir):
            os.makedirs(mindrecord_dir)
        if args_opt.dataset == "coco":
            if os.path.isdir(config.COCO_ROOT):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("coco", True, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("COCO_ROOT does not exist.")
        else:
            if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH):
                print("Create Mindrecord.")
                data_to_mindrecord_byte_image("other", True, prefix)
                print("Create Mindrecord Done, at {}".format(mindrecord_dir))
            else:
                print("IMAGE_DIR or ANNO_PATH does not exist.")

    if not args_opt.only_create_dataset:
        loss_scale = float(args_opt.loss_scale)

        # When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0.
        dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
                                     batch_size=args_opt.batch_size, device_num=device_num, rank=rank)

        dataset_size = dataset.get_dataset_size()
        print("Create dataset done!")

        ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config)
        net = SSDWithLossCell(ssd, config)
        init_net_param(net)

        # checkpoint
        ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
        ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)

        lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
                           warmup_epochs=max(args_opt.epoch_size // 20, 1),
                           total_epochs=args_opt.epoch_size,
                           steps_per_epoch=dataset_size))
        opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
        net = TrainingWrapper(net, opt, loss_scale)

        if args_opt.checkpoint_path != "":
            param_dict = load_checkpoint(args_opt.checkpoint_path)
            load_param_into_net(net, param_dict)

        callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]

        model = Model(net)
        dataset_sink_mode = False
        if args_opt.mode == "sink":
            print("In sink mode, one epoch returns one loss.")
            dataset_sink_mode = True
        print("Start train SSD, the first epoch will be slower because of the graph compilation.")
        model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)


if __name__ == '__main__':
    main()

@@ -0,0 +1,208 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""

import numpy as np
from config import ConfigSSD
from dataset import ssd_bboxes_decode


def calc_iou(bbox_pred, bbox_ground):
    """Calculate iou of predicted bbox and ground truth."""
    bbox_pred = np.expand_dims(bbox_pred, axis=0)

    pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
    pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
    pred_area = pred_w * pred_h

    gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
    gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
    gt_area = gt_w * gt_h

    iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
    ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])

    iw = np.maximum(iw, 0)
    ih = np.maximum(ih, 0)
    intersection_area = iw * ih

    union_area = pred_area + gt_area - intersection_area
    union_area = np.maximum(union_area, np.finfo(float).eps)

    iou = intersection_area * 1. / union_area
    return iou
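
# Illustrative example: calc_iou(np.array([0., 0., 2., 2.]),
# np.array([[1., 1., 3., 3.]])) gives intersection 1, union 4 + 4 - 1 = 7,
# so the returned IoU is about 0.143.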


def apply_nms(all_boxes, all_scores, thres, max_boxes):
    """Apply NMS to bboxes."""
    x1 = all_boxes[:, 0]
    y1 = all_boxes[:, 1]
    x2 = all_boxes[:, 2]
    y2 = all_boxes[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    order = all_scores.argsort()[::-1]
    keep = []

    while order.size > 0:
        i = order[0]
        keep.append(i)

        if len(keep) >= max_boxes:
            break

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        inds = np.where(ovr <= thres)[0]

        order = order[inds + 1]
    return keep
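
# Illustrative example: with boxes [[0, 0, 10, 10], [1, 1, 11, 11],
# [20, 20, 30, 30]], scores [0.9, 0.8, 0.7] and thres 0.5, the second box
# overlaps the first above the threshold and is suppressed, so apply_nms
# keeps indices [0, 2].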


def calc_ap(recall, precision):
    """Calculate AP."""
    correct_recall = np.concatenate(([0.], recall, [1.]))
    correct_precision = np.concatenate(([0.], precision, [0.]))

    for i in range(correct_recall.size - 1, 0, -1):
        correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])

    i = np.where(correct_recall[1:] != correct_recall[:-1])[0]

    ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])

    return ap
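
# Illustrative example: calc_ap(np.array([0.5, 1.0]), np.array([1.0, 0.5]))
# evaluates the area under the interpolated precision-recall curve:
# 0.5 * 1.0 + 0.5 * 0.5 = 0.75.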


def metrics(pred_data):
    """Calculate mAP of predicted bboxes."""
    config = ConfigSSD()
    num_classes = config.NUM_CLASSES

    all_detections = [None for i in range(num_classes)]
    all_pred_scores = [None for i in range(num_classes)]
    all_annotations = [None for i in range(num_classes)]
    average_precisions = {}
    num = [0 for i in range(num_classes)]
    accurate_num = [0 for i in range(num_classes)]

    for sample in pred_data:
        pred_boxes = sample['boxes']
        boxes_scores = sample['box_scores']
        annotation = sample['annotation']
        image_shape = sample['image_shape']

        annotation = np.squeeze(annotation, axis=0)
        image_shape = np.squeeze(image_shape, axis=0)

        pred_labels = np.argmax(boxes_scores, axis=-1)
        index = np.nonzero(pred_labels)
        pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape)

        pred_boxes = pred_boxes.clip(0, 1)
        boxes_scores = np.max(boxes_scores, axis=-1)
        boxes_scores = boxes_scores[index]
        pred_labels = pred_labels[index]

        top_k = 50

        for c in range(1, num_classes):
            if len(pred_labels) >= 1:
                class_box_scores = boxes_scores[pred_labels == c]
                class_boxes = pred_boxes[pred_labels == c]

                nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)

                class_boxes = class_boxes[nms_index]
                class_box_scores = class_box_scores[nms_index]

                cmask = class_box_scores > 0.5
                class_boxes = class_boxes[cmask]
                class_box_scores = class_box_scores[cmask]

                all_detections[c] = class_boxes
                all_pred_scores[c] = class_box_scores

        for c in range(1, num_classes):
            if len(annotation) >= 1:
                all_annotations[c] = annotation[annotation[:, 4] == c, :4]

        for c in range(1, num_classes):
            false_positives = np.zeros((0,))
            true_positives = np.zeros((0,))
            scores = np.zeros((0,))
            num_annotations = 0.0

            annotations = all_annotations[c]
            num_annotations += annotations.shape[0]
            detections = all_detections[c]
            pred_scores = all_pred_scores[c]

            for index, detection in enumerate(detections):
                scores = np.append(scores, pred_scores[index])
                if len(annotations) >= 1:
                    IoUs = calc_iou(detection, annotations)
                    assigned_anno = np.argmax(IoUs)
                    max_overlap = IoUs[assigned_anno]

                    if max_overlap >= 0.5:
                        false_positives = np.append(false_positives, 0)
                        true_positives = np.append(true_positives, 1)
                    else:
                        false_positives = np.append(false_positives, 1)
                        true_positives = np.append(true_positives, 0)
                else:
                    false_positives = np.append(false_positives, 1)
                    true_positives = np.append(true_positives, 0)

            if num_annotations == 0:
                if c not in average_precisions.keys():
                    average_precisions[c] = 0
                continue
            accurate_num[c] = 1
            indices = np.argsort(-scores)
            false_positives = false_positives[indices]
            true_positives = true_positives[indices]

            false_positives = np.cumsum(false_positives)
            true_positives = np.cumsum(true_positives)

            recall = true_positives * 1. / num_annotations
            precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)

            average_precision = calc_ap(recall, precision)

            if c not in average_precisions.keys():
                average_precisions[c] = average_precision
            else:
                average_precisions[c] += average_precision

            num[c] += 1

    count = 0
    for key in average_precisions:
        if num[key] != 0:
            count += (average_precisions[key] / num[key])

    mAP = count * 1. / accurate_num.count(1)
    return mAP

File diff suppressed because it is too large