parent
72fd41786c
commit
6000abcdf3
@ -1,88 +0,0 @@
|
|||||||
# SSD Example
|
|
||||||
|
|
||||||
## Description
|
|
||||||
|
|
||||||
SSD network based on MobileNetV2, with support for training and evaluation.
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
|
||||||
|
|
||||||
- Dataset
|
|
||||||
|
|
||||||
We use coco2017 as training dataset in this example by default, and you can also use your own datasets.
|
|
||||||
|
|
||||||
1. If coco dataset is used. **Select dataset to coco when run script.**
|
|
||||||
Install Cython and pycocotool.
|
|
||||||
|
|
||||||
```
|
|
||||||
pip install Cython
|
|
||||||
|
|
||||||
pip install pycocotools
|
|
||||||
```
|
|
||||||
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
|
|
||||||
|
|
||||||
|
|
||||||
```
|
|
||||||
└─coco2017
|
|
||||||
├── annotations # annotation jsons
|
|
||||||
├── train2017 # train dataset
|
|
||||||
└── val2017 # infer dataset
|
|
||||||
```
|
|
||||||
|
|
||||||
2. If your own dataset is used. **Select dataset to other when run script.**
|
|
||||||
Organize the dataset infomation into a TXT file, each row in the file is as follows:
|
|
||||||
|
|
||||||
```
|
|
||||||
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
|
|
||||||
```
|
|
||||||
|
|
||||||
Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `IMAGE_DIR`(dataset directory) and the relative path in `ANNO_PATH`(the TXT file path), `IMAGE_DIR` and `ANNO_PATH` are setting in `config.py`.
|
|
||||||
|
|
||||||
|
|
||||||
## Running the example
|
|
||||||
|
|
||||||
### Training
|
|
||||||
|
|
||||||
To train the model, run `train.py`. If the `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `COCO_ROOT`(coco dataset) or `IMAGE_DIR` and `ANNO_PATH`(own dataset). **Note if MINDRECORD_DIR isn't empty, it will use MINDRECORD_DIR instead of raw images.**
|
|
||||||
|
|
||||||
|
|
||||||
- Stand alone mode
|
|
||||||
|
|
||||||
```
|
|
||||||
python train.py --dataset coco
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
You can run ```python train.py -h``` to get more information.
|
|
||||||
|
|
||||||
|
|
||||||
- Distribute mode
|
|
||||||
|
|
||||||
```
|
|
||||||
sh run_distribute_train.sh 8 150 coco /data/hccl.json
|
|
||||||
```
|
|
||||||
|
|
||||||
The input parameters are device numbers, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**
|
|
||||||
|
|
||||||
You will get the loss value of each step as following:
|
|
||||||
|
|
||||||
```
|
|
||||||
epoch: 1 step: 455, loss is 5.8653416
|
|
||||||
epoch: 2 step: 455, loss is 5.4292373
|
|
||||||
epoch: 3 step: 455, loss is 5.458992
|
|
||||||
...
|
|
||||||
epoch: 148 step: 455, loss is 1.8340507
|
|
||||||
epoch: 149 step: 455, loss is 2.0876894
|
|
||||||
epoch: 150 step: 455, loss is 2.239692
|
|
||||||
```
|
|
||||||
|
|
||||||
### Evaluation
|
|
||||||
|
|
||||||
for evaluation , run `eval.py` with `ckpt_path`. `ckpt_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
|
|
||||||
|
|
||||||
```
|
|
||||||
python eval.py --ckpt_path ssd.ckpt --dataset coco
|
|
||||||
```
|
|
||||||
|
|
||||||
You can run ```python eval.py -h``` to get more information.
|
|
@ -1,64 +0,0 @@
|
|||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
"""Config parameters for SSD models."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigSSD:
|
|
||||||
"""
|
|
||||||
Config parameters for SSD.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
ConfigSSD().
|
|
||||||
"""
|
|
||||||
IMG_SHAPE = [300, 300]
|
|
||||||
NUM_SSD_BOXES = 1917
|
|
||||||
NEG_PRE_POSITIVE = 3
|
|
||||||
MATCH_THRESHOLD = 0.5
|
|
||||||
|
|
||||||
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
|
|
||||||
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
|
|
||||||
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
|
|
||||||
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
|
|
||||||
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
|
|
||||||
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
|
|
||||||
SCALES = [21, 45, 99, 153, 207, 261, 315]
|
|
||||||
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
|
|
||||||
STEPS = (16, 32, 64, 100, 150, 300)
|
|
||||||
PRIOR_SCALING = (0.1, 0.2)
|
|
||||||
|
|
||||||
|
|
||||||
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
|
|
||||||
MINDRECORD_DIR = "MindRecord_COCO"
|
|
||||||
COCO_ROOT = "coco2017"
|
|
||||||
TRAIN_DATA_TYPE = "train2017"
|
|
||||||
VAL_DATA_TYPE = "val2017"
|
|
||||||
INSTANCES_SET = "annotations/instances_{}.json"
|
|
||||||
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
||||||
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
|
|
||||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
|
||||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
|
||||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
|
||||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
|
||||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
|
||||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
|
||||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
|
||||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
|
||||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
|
||||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
|
||||||
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
|
|
||||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
|
||||||
'teddy bear', 'hair drier', 'toothbrush')
|
|
||||||
NUM_CLASSES = len(COCO_CLASSES)
|
|
File diff suppressed because it is too large
Load Diff
@ -1,206 +0,0 @@
|
|||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""metrics utils"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from config import ConfigSSD
|
|
||||||
from dataset import ssd_bboxes_decode
|
|
||||||
|
|
||||||
|
|
||||||
def calc_iou(bbox_pred, bbox_ground):
|
|
||||||
"""Calculate iou of predicted bbox and ground truth."""
|
|
||||||
bbox_pred = np.expand_dims(bbox_pred, axis=0)
|
|
||||||
|
|
||||||
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
|
|
||||||
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
|
|
||||||
pred_area = pred_w * pred_h
|
|
||||||
|
|
||||||
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
|
|
||||||
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
|
|
||||||
gt_area = gt_w * gt_h
|
|
||||||
|
|
||||||
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
|
|
||||||
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
|
|
||||||
|
|
||||||
iw = np.maximum(iw, 0)
|
|
||||||
ih = np.maximum(ih, 0)
|
|
||||||
intersection_area = iw * ih
|
|
||||||
|
|
||||||
union_area = pred_area + gt_area - intersection_area
|
|
||||||
union_area = np.maximum(union_area, np.finfo(float).eps)
|
|
||||||
|
|
||||||
iou = intersection_area * 1. / union_area
|
|
||||||
return iou
|
|
||||||
|
|
||||||
|
|
||||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
|
||||||
"""Apply NMS to bboxes."""
|
|
||||||
x1 = all_boxes[:, 0]
|
|
||||||
y1 = all_boxes[:, 1]
|
|
||||||
x2 = all_boxes[:, 2]
|
|
||||||
y2 = all_boxes[:, 3]
|
|
||||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
|
||||||
|
|
||||||
order = all_scores.argsort()[::-1]
|
|
||||||
keep = []
|
|
||||||
|
|
||||||
while order.size > 0:
|
|
||||||
i = order[0]
|
|
||||||
keep.append(i)
|
|
||||||
|
|
||||||
if len(keep) >= max_boxes:
|
|
||||||
break
|
|
||||||
|
|
||||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
|
||||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
|
||||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
|
||||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
|
||||||
|
|
||||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
|
||||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
|
||||||
inter = w * h
|
|
||||||
|
|
||||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
|
||||||
|
|
||||||
inds = np.where(ovr <= thres)[0]
|
|
||||||
|
|
||||||
order = order[inds + 1]
|
|
||||||
return keep
|
|
||||||
|
|
||||||
|
|
||||||
def calc_ap(recall, precision):
|
|
||||||
"""Calculate AP."""
|
|
||||||
correct_recall = np.concatenate(([0.], recall, [1.]))
|
|
||||||
correct_precision = np.concatenate(([0.], precision, [0.]))
|
|
||||||
|
|
||||||
for i in range(correct_recall.size - 1, 0, -1):
|
|
||||||
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
|
|
||||||
|
|
||||||
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
|
|
||||||
|
|
||||||
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
|
|
||||||
|
|
||||||
return ap
|
|
||||||
|
|
||||||
def metrics(pred_data):
|
|
||||||
"""Calculate mAP of predicted bboxes."""
|
|
||||||
config = ConfigSSD()
|
|
||||||
num_classes = config.NUM_CLASSES
|
|
||||||
|
|
||||||
all_detections = [None for i in range(num_classes)]
|
|
||||||
all_pred_scores = [None for i in range(num_classes)]
|
|
||||||
all_annotations = [None for i in range(num_classes)]
|
|
||||||
average_precisions = {}
|
|
||||||
num = [0 for i in range(num_classes)]
|
|
||||||
accurate_num = [0 for i in range(num_classes)]
|
|
||||||
|
|
||||||
for sample in pred_data:
|
|
||||||
pred_boxes = sample['boxes']
|
|
||||||
boxes_scores = sample['box_scores']
|
|
||||||
annotation = sample['annotation']
|
|
||||||
|
|
||||||
annotation = np.squeeze(annotation, axis=0)
|
|
||||||
|
|
||||||
pred_labels = np.argmax(boxes_scores, axis=-1)
|
|
||||||
index = np.nonzero(pred_labels)
|
|
||||||
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
|
|
||||||
|
|
||||||
pred_boxes = pred_boxes.clip(0, 1)
|
|
||||||
boxes_scores = np.max(boxes_scores, axis=-1)
|
|
||||||
boxes_scores = boxes_scores[index]
|
|
||||||
pred_labels = pred_labels[index]
|
|
||||||
|
|
||||||
top_k = 50
|
|
||||||
|
|
||||||
for c in range(1, num_classes):
|
|
||||||
if len(pred_labels) >= 1:
|
|
||||||
class_box_scores = boxes_scores[pred_labels == c]
|
|
||||||
class_boxes = pred_boxes[pred_labels == c]
|
|
||||||
|
|
||||||
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
|
|
||||||
|
|
||||||
class_boxes = class_boxes[nms_index]
|
|
||||||
class_box_scores = class_box_scores[nms_index]
|
|
||||||
|
|
||||||
cmask = class_box_scores > 0.5
|
|
||||||
class_boxes = class_boxes[cmask]
|
|
||||||
class_box_scores = class_box_scores[cmask]
|
|
||||||
|
|
||||||
all_detections[c] = class_boxes
|
|
||||||
all_pred_scores[c] = class_box_scores
|
|
||||||
|
|
||||||
for c in range(1, num_classes):
|
|
||||||
if len(annotation) >= 1:
|
|
||||||
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
|
|
||||||
|
|
||||||
for c in range(1, num_classes):
|
|
||||||
false_positives = np.zeros((0,))
|
|
||||||
true_positives = np.zeros((0,))
|
|
||||||
scores = np.zeros((0,))
|
|
||||||
num_annotations = 0.0
|
|
||||||
|
|
||||||
annotations = all_annotations[c]
|
|
||||||
num_annotations += annotations.shape[0]
|
|
||||||
detections = all_detections[c]
|
|
||||||
pred_scores = all_pred_scores[c]
|
|
||||||
|
|
||||||
for index, detection in enumerate(detections):
|
|
||||||
scores = np.append(scores, pred_scores[index])
|
|
||||||
if len(annotations) >= 1:
|
|
||||||
IoUs = calc_iou(detection, annotations)
|
|
||||||
assigned_anno = np.argmax(IoUs)
|
|
||||||
max_overlap = IoUs[assigned_anno]
|
|
||||||
|
|
||||||
if max_overlap >= 0.5:
|
|
||||||
false_positives = np.append(false_positives, 0)
|
|
||||||
true_positives = np.append(true_positives, 1)
|
|
||||||
else:
|
|
||||||
false_positives = np.append(false_positives, 1)
|
|
||||||
true_positives = np.append(true_positives, 0)
|
|
||||||
else:
|
|
||||||
false_positives = np.append(false_positives, 1)
|
|
||||||
true_positives = np.append(true_positives, 0)
|
|
||||||
|
|
||||||
if num_annotations == 0:
|
|
||||||
if c not in average_precisions.keys():
|
|
||||||
average_precisions[c] = 0
|
|
||||||
continue
|
|
||||||
accurate_num[c] = 1
|
|
||||||
indices = np.argsort(-scores)
|
|
||||||
false_positives = false_positives[indices]
|
|
||||||
true_positives = true_positives[indices]
|
|
||||||
|
|
||||||
false_positives = np.cumsum(false_positives)
|
|
||||||
true_positives = np.cumsum(true_positives)
|
|
||||||
|
|
||||||
recall = true_positives * 1. / num_annotations
|
|
||||||
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
|
|
||||||
|
|
||||||
average_precision = calc_ap(recall, precision)
|
|
||||||
|
|
||||||
if c not in average_precisions.keys():
|
|
||||||
average_precisions[c] = average_precision
|
|
||||||
else:
|
|
||||||
average_precisions[c] += average_precision
|
|
||||||
|
|
||||||
num[c] += 1
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for key in average_precisions:
|
|
||||||
if num[key] != 0:
|
|
||||||
count += (average_precisions[key] / num[key])
|
|
||||||
|
|
||||||
mAP = count * 1. / accurate_num.count(1)
|
|
||||||
return mAP
|
|
@ -0,0 +1,119 @@
|
|||||||
|
# SSD Example
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
SSD network based on MobileNetV2, with support for training and evaluation.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||||
|
|
||||||
|
- Dataset
|
||||||
|
|
||||||
|
We use coco2017 as training dataset in this example by default, and you can also use your own datasets.
|
||||||
|
|
||||||
|
1. If coco dataset is used. **Select dataset to coco when run script.**
|
||||||
|
Install Cython and pycocotool.
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install Cython
|
||||||
|
|
||||||
|
pip install pycocotools
|
||||||
|
```
|
||||||
|
And change the coco_root and other settings you need in `config.py`. The directory structure is as follows:
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
└─cocodataset
|
||||||
|
├─annotations
|
||||||
|
├─instance_train2017.json
|
||||||
|
└─instance_val2017.json
|
||||||
|
├─val2017
|
||||||
|
└─train2017
|
||||||
|
```
|
||||||
|
|
||||||
|
2. If your own dataset is used. **Select dataset to other when run script.**
|
||||||
|
Organize the dataset infomation into a TXT file, each row in the file is as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
|
||||||
|
```
|
||||||
|
|
||||||
|
Each row is an image annotation which split by space, the first column is a relative path of image, the others are box and class infomations of the format [xmin,ymin,xmax,ymax,class]. We read image from an image path joined by the `image_dir`(dataset directory) and the relative path in `anno_path`(the TXT file path), `image_dir` and `anno_path` are setting in `config.py`.
|
||||||
|
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
To train the model, run `train.py`. If the `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files by `coco_root`(coco dataset) or `iamge_dir` and `anno_path`(own dataset). **Note if mindrecord_dir isn't empty, it will use mindrecord_dir instead of raw images.**
|
||||||
|
|
||||||
|
|
||||||
|
- Stand alone mode
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --dataset coco
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
You can run ```python train.py -h``` to get more information.
|
||||||
|
|
||||||
|
|
||||||
|
- Distribute mode
|
||||||
|
|
||||||
|
```
|
||||||
|
sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json
|
||||||
|
```
|
||||||
|
|
||||||
|
The input parameters are device numbers, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute path.**
|
||||||
|
|
||||||
|
You will get the loss value of each step as following:
|
||||||
|
|
||||||
|
```
|
||||||
|
epoch: 1 step: 458, loss is 3.1681802
|
||||||
|
epoch time: 228752.4654865265, per step time: 499.4595316299705
|
||||||
|
epoch: 2 step: 458, loss is 2.8847265
|
||||||
|
epoch time: 38912.93382644653, per step time: 84.96273761232868
|
||||||
|
epoch: 3 step: 458, loss is 2.8398118
|
||||||
|
epoch time: 38769.184827804565, per step time: 84.64887516987896
|
||||||
|
...
|
||||||
|
|
||||||
|
epoch: 498 step: 458, loss is 0.70908034
|
||||||
|
epoch time: 38771.079778671265, per step time: 84.65301261718616
|
||||||
|
epoch: 499 step: 458, loss is 0.7974688
|
||||||
|
epoch time: 38787.413120269775, per step time: 84.68867493508685
|
||||||
|
epoch: 500 step: 458, loss is 0.5548882
|
||||||
|
epoch time: 39064.8467540741, per step time: 85.29442522723602
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
for evaluation , run `eval.py` with `checkpoint_path`. `checkpoint_path` is the path of [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
|
||||||
|
|
||||||
|
```
|
||||||
|
python eval.py --checkpoint_path ssd.ckpt --dataset coco
|
||||||
|
```
|
||||||
|
|
||||||
|
You can run ```python eval.py -h``` to get more information.
|
||||||
|
|
||||||
|
You will get the result as following:
|
||||||
|
|
||||||
|
```
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.189
|
||||||
|
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.341
|
||||||
|
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.183
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.040
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.181
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.326
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.213
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.348
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.380
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.124
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.412
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.588
|
||||||
|
|
||||||
|
========================================
|
||||||
|
|
||||||
|
mAP: 0.18937438355383837
|
||||||
|
```
|
@ -0,0 +1,165 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Bbox utils"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import itertools as it
|
||||||
|
import numpy as np
|
||||||
|
from .config import config
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratDefaultBoxes():
|
||||||
|
"""
|
||||||
|
Generate Default boxes for SSD, follows the order of (W, H, archor_sizes).
|
||||||
|
`self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w].
|
||||||
|
`self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
fk = config.img_shape[0] / np.array(config.steps)
|
||||||
|
scale_rate = (config.max_scale - config.min_scale) / (len(config.num_default) - 1)
|
||||||
|
scales = [config.min_scale + scale_rate * i for i in range(len(config.num_default))] + [1.0]
|
||||||
|
self.default_boxes = []
|
||||||
|
for idex, feature_size in enumerate(config.feature_size):
|
||||||
|
sk1 = scales[idex]
|
||||||
|
sk2 = scales[idex + 1]
|
||||||
|
sk3 = math.sqrt(sk1 * sk2)
|
||||||
|
if idex == 0:
|
||||||
|
w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
|
||||||
|
all_sizes = [(0.1, 0.1), (w, h), (h, w)]
|
||||||
|
else:
|
||||||
|
all_sizes = [(sk1, sk1)]
|
||||||
|
for aspect_ratio in config.aspect_ratios[idex]:
|
||||||
|
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
|
||||||
|
all_sizes.append((w, h))
|
||||||
|
all_sizes.append((h, w))
|
||||||
|
all_sizes.append((sk3, sk3))
|
||||||
|
|
||||||
|
assert len(all_sizes) == config.num_default[idex]
|
||||||
|
|
||||||
|
for i, j in it.product(range(feature_size), repeat=2):
|
||||||
|
for w, h in all_sizes:
|
||||||
|
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
|
||||||
|
self.default_boxes.append([cy, cx, h, w])
|
||||||
|
|
||||||
|
def to_ltrb(cy, cx, h, w):
|
||||||
|
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
|
||||||
|
|
||||||
|
# For IoU calculation
|
||||||
|
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
|
||||||
|
self.default_boxes = np.array(self.default_boxes, dtype='float32')
|
||||||
|
|
||||||
|
|
||||||
|
default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
|
||||||
|
default_boxes = GeneratDefaultBoxes().default_boxes
|
||||||
|
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
|
||||||
|
vol_anchors = (x2 - x1) * (y2 - y1)
|
||||||
|
matching_threshold = config.match_thershold
|
||||||
|
|
||||||
|
|
||||||
|
def ssd_bboxes_encode(boxes):
|
||||||
|
"""
|
||||||
|
Labels anchors with ground truth inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
gt_loc: location ground truth with shape [num_anchors, 4].
|
||||||
|
gt_label: class ground truth with shape [num_anchors, 1].
|
||||||
|
num_matched_boxes: number of positives in an image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def jaccard_with_anchors(bbox):
|
||||||
|
"""Compute jaccard score a box and the anchors."""
|
||||||
|
# Intersection bbox and volume.
|
||||||
|
ymin = np.maximum(y1, bbox[0])
|
||||||
|
xmin = np.maximum(x1, bbox[1])
|
||||||
|
ymax = np.minimum(y2, bbox[2])
|
||||||
|
xmax = np.minimum(x2, bbox[3])
|
||||||
|
w = np.maximum(xmax - xmin, 0.)
|
||||||
|
h = np.maximum(ymax - ymin, 0.)
|
||||||
|
|
||||||
|
# Volumes.
|
||||||
|
inter_vol = h * w
|
||||||
|
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
|
||||||
|
jaccard = inter_vol / union_vol
|
||||||
|
return np.squeeze(jaccard)
|
||||||
|
|
||||||
|
pre_scores = np.zeros((config.num_ssd_boxes), dtype=np.float32)
|
||||||
|
t_boxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
|
||||||
|
t_label = np.zeros((config.num_ssd_boxes), dtype=np.int64)
|
||||||
|
for bbox in boxes:
|
||||||
|
label = int(bbox[4])
|
||||||
|
scores = jaccard_with_anchors(bbox)
|
||||||
|
idx = np.argmax(scores)
|
||||||
|
scores[idx] = 2.0
|
||||||
|
mask = (scores > matching_threshold)
|
||||||
|
mask = mask & (scores > pre_scores)
|
||||||
|
pre_scores = np.maximum(pre_scores, scores * mask)
|
||||||
|
t_label = mask * label + (1 - mask) * t_label
|
||||||
|
for i in range(4):
|
||||||
|
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
|
||||||
|
|
||||||
|
index = np.nonzero(t_label)
|
||||||
|
|
||||||
|
# Transform to ltrb.
|
||||||
|
bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
|
||||||
|
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
|
||||||
|
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
|
||||||
|
|
||||||
|
# Encode features.
|
||||||
|
bboxes_t = bboxes[index]
|
||||||
|
default_boxes_t = default_boxes[index]
|
||||||
|
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
|
||||||
|
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.prior_scaling[1]
|
||||||
|
bboxes[index] = bboxes_t
|
||||||
|
|
||||||
|
num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
|
||||||
|
return bboxes, t_label.astype(np.int32), num_match
|
||||||
|
|
||||||
|
|
||||||
|
def ssd_bboxes_decode(boxes):
|
||||||
|
"""Decode predict boxes to [y, x, h, w]"""
|
||||||
|
boxes_t = boxes.copy()
|
||||||
|
default_boxes_t = default_boxes.copy()
|
||||||
|
boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
|
||||||
|
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]
|
||||||
|
|
||||||
|
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
|
||||||
|
|
||||||
|
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
|
||||||
|
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
|
||||||
|
|
||||||
|
return np.clip(bboxes, 0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def intersect(box_a, box_b):
|
||||||
|
"""Compute the intersect of two sets of boxes."""
|
||||||
|
max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
|
||||||
|
min_yx = np.maximum(box_a[:, :2], box_b[:2])
|
||||||
|
inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
|
||||||
|
return inter[:, 0] * inter[:, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def jaccard_numpy(box_a, box_b):
|
||||||
|
"""Compute the jaccard overlap of two sets of boxes."""
|
||||||
|
inter = intersect(box_a, box_b)
|
||||||
|
area_a = ((box_a[:, 2] - box_a[:, 0]) *
|
||||||
|
(box_a[:, 3] - box_a[:, 1]))
|
||||||
|
area_b = ((box_b[2] - box_b[0]) *
|
||||||
|
(box_b[3] - box_b[1]))
|
||||||
|
union = area_a + area_b - inter
|
||||||
|
return inter / union
|
@ -0,0 +1,127 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Coco metrics utils"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from .config import config
|
||||||
|
from .box_utils import ssd_bboxes_decode
|
||||||
|
|
||||||
|
|
||||||
|
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||||
|
"""Apply NMS to bboxes."""
|
||||||
|
y1 = all_boxes[:, 0]
|
||||||
|
x1 = all_boxes[:, 1]
|
||||||
|
y2 = all_boxes[:, 2]
|
||||||
|
x2 = all_boxes[:, 3]
|
||||||
|
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||||
|
|
||||||
|
order = all_scores.argsort()[::-1]
|
||||||
|
keep = []
|
||||||
|
|
||||||
|
while order.size > 0:
|
||||||
|
i = order[0]
|
||||||
|
keep.append(i)
|
||||||
|
|
||||||
|
if len(keep) >= max_boxes:
|
||||||
|
break
|
||||||
|
|
||||||
|
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||||
|
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||||
|
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||||
|
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||||
|
|
||||||
|
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||||
|
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||||
|
inter = w * h
|
||||||
|
|
||||||
|
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||||
|
|
||||||
|
inds = np.where(ovr <= thres)[0]
|
||||||
|
|
||||||
|
order = order[inds + 1]
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|
||||||
|
def metrics(pred_data):
|
||||||
|
"""Calculate mAP of predicted bboxes."""
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
num_classes = config.num_classes
|
||||||
|
|
||||||
|
coco_root = config.coco_root
|
||||||
|
data_type = config.val_data_type
|
||||||
|
|
||||||
|
#Classes need to train or test.
|
||||||
|
val_cls = config.coco_classes
|
||||||
|
val_cls_dict = {}
|
||||||
|
for i, cls in enumerate(val_cls):
|
||||||
|
val_cls_dict[i] = cls
|
||||||
|
|
||||||
|
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
|
||||||
|
coco_gt = COCO(anno_json)
|
||||||
|
classs_dict = {}
|
||||||
|
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
|
||||||
|
for cat in cat_ids:
|
||||||
|
classs_dict[cat["name"]] = cat["id"]
|
||||||
|
|
||||||
|
predictions = []
|
||||||
|
img_ids = []
|
||||||
|
|
||||||
|
for sample in pred_data:
|
||||||
|
pred_boxes = sample['boxes']
|
||||||
|
box_scores = sample['box_scores']
|
||||||
|
img_id = sample['img_id']
|
||||||
|
h, w = sample['image_shape']
|
||||||
|
|
||||||
|
pred_boxes = ssd_bboxes_decode(pred_boxes)
|
||||||
|
final_boxes = []
|
||||||
|
final_label = []
|
||||||
|
final_score = []
|
||||||
|
img_ids.append(img_id)
|
||||||
|
|
||||||
|
for c in range(1, num_classes):
|
||||||
|
class_box_scores = box_scores[:, c]
|
||||||
|
score_mask = class_box_scores > config.min_score
|
||||||
|
class_box_scores = class_box_scores[score_mask]
|
||||||
|
class_boxes = pred_boxes[score_mask] * [h, w, h, w]
|
||||||
|
|
||||||
|
if score_mask.any():
|
||||||
|
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
|
||||||
|
class_boxes = class_boxes[nms_index]
|
||||||
|
class_box_scores = class_box_scores[nms_index]
|
||||||
|
|
||||||
|
final_boxes += class_boxes.tolist()
|
||||||
|
final_score += class_box_scores.tolist()
|
||||||
|
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)
|
||||||
|
|
||||||
|
for loc, label, score in zip(final_boxes, final_label, final_score):
|
||||||
|
res = {}
|
||||||
|
res['image_id'] = img_id
|
||||||
|
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
|
||||||
|
res['score'] = score
|
||||||
|
res['category_id'] = label
|
||||||
|
predictions.append(res)
|
||||||
|
with open('predictions.json', 'w') as f:
|
||||||
|
json.dump(predictions, f)
|
||||||
|
|
||||||
|
coco_dt = coco_gt.loadRes('predictions.json')
|
||||||
|
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
|
||||||
|
E.params.imgIds = img_ids
|
||||||
|
E.evaluate()
|
||||||
|
E.accumulate()
|
||||||
|
E.summarize()
|
||||||
|
return E.stats[0]
|
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#" ============================================================================
|
||||||
|
|
||||||
|
"""Config parameters for SSD models."""
|
||||||
|
|
||||||
|
from easydict import EasyDict as ed
|
||||||
|
|
||||||
|
config = ed({
|
||||||
|
"img_shape": [300, 300],
|
||||||
|
"num_ssd_boxes": 1917,
|
||||||
|
"neg_pre_positive": 3,
|
||||||
|
"match_thershold": 0.5,
|
||||||
|
"nms_thershold": 0.6,
|
||||||
|
"min_score": 0.1,
|
||||||
|
"max_boxes": 100,
|
||||||
|
|
||||||
|
# learing rate settings
|
||||||
|
"global_step": 0,
|
||||||
|
"lr_init": 0.001,
|
||||||
|
"lr_end_rate": 0.001,
|
||||||
|
"warmup_epochs": 2,
|
||||||
|
"momentum": 0.9,
|
||||||
|
"weight_decay": 1.5e-4,
|
||||||
|
|
||||||
|
# network
|
||||||
|
"num_default": [3, 6, 6, 6, 6, 6],
|
||||||
|
"extras_in_channels": [256, 576, 1280, 512, 256, 256],
|
||||||
|
"extras_out_channels": [576, 1280, 512, 256, 256, 128],
|
||||||
|
"extras_srides": [1, 1, 2, 2, 2, 2],
|
||||||
|
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
|
||||||
|
"feature_size": [19, 10, 5, 3, 2, 1],
|
||||||
|
"min_scale": 0.2,
|
||||||
|
"max_scale": 0.95,
|
||||||
|
"aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
|
||||||
|
"steps": (16, 32, 64, 100, 150, 300),
|
||||||
|
"prior_scaling": (0.1, 0.2),
|
||||||
|
"gamma": 2.0,
|
||||||
|
"alpha": 0.75,
|
||||||
|
|
||||||
|
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||||
|
"mindrecord_dir": "/data/MindRecord_COCO",
|
||||||
|
"coco_root": "/data/coco2017",
|
||||||
|
"train_data_type": "train2017",
|
||||||
|
"val_data_type": "val2017",
|
||||||
|
"instances_set": "annotations/instances_{}.json",
|
||||||
|
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||||
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||||
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||||
|
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||||
|
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||||
|
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||||
|
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||||
|
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||||
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||||
|
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||||
|
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||||
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||||
|
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||||
|
'teddy bear', 'hair drier', 'toothbrush'),
|
||||||
|
"num_classes": 81,
|
||||||
|
|
||||||
|
# if coco used, `image_dir` and `anno_path` are useless.
|
||||||
|
"image_dir": "",
|
||||||
|
"anno_path": "",
|
||||||
|
})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,41 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Parameters utils"""
|
||||||
|
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||||
|
|
||||||
|
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||||
|
"""Init the parameters in net."""
|
||||||
|
params = network.trainable_params()
|
||||||
|
for p in params:
|
||||||
|
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||||
|
if initialize_mode == 'TruncatedNormal':
|
||||||
|
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype()))
|
||||||
|
else:
|
||||||
|
p.set_parameter_data(initialize_mode, p.data.shape(), p.data.dtype())
|
||||||
|
|
||||||
|
|
||||||
|
def load_backbone_params(network, param_dict):
|
||||||
|
"""Init the parameters from pre-train model, default is mobilenetv2."""
|
||||||
|
for _, param in net.parameters_and_names():
|
||||||
|
param_name = param.name.replace('network.backbone.', '')
|
||||||
|
name_split = param_name.split('.')
|
||||||
|
if 'features_1' in param_name:
|
||||||
|
param_name = param_name.replace('features_1', 'features')
|
||||||
|
if 'features_2' in param_name:
|
||||||
|
param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
|
||||||
|
if param_name in param_dict:
|
||||||
|
param.set_parameter_data(param_dict[param_name].data)
|
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Learning rate schedule"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
|
||||||
|
"""
|
||||||
|
generate learning rate array
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_step(int): total steps of the training
|
||||||
|
lr_init(float): init learning rate
|
||||||
|
lr_end(float): end learning rate
|
||||||
|
lr_max(float): max learning rate
|
||||||
|
warmup_epochs(float): number of warmup epochs
|
||||||
|
total_epochs(int): total epoch of training
|
||||||
|
steps_per_epoch(int): steps of one epoch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array, learning rate array
|
||||||
|
"""
|
||||||
|
lr_each_step = []
|
||||||
|
total_steps = steps_per_epoch * total_epochs
|
||||||
|
warmup_steps = steps_per_epoch * warmup_epochs
|
||||||
|
for i in range(total_steps):
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||||
|
else:
|
||||||
|
lr = lr_end + \
|
||||||
|
(lr_max - lr_end) * \
|
||||||
|
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
|
||||||
|
if lr < 0.0:
|
||||||
|
lr = 0.0
|
||||||
|
lr_each_step.append(lr)
|
||||||
|
|
||||||
|
current_step = global_step
|
||||||
|
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||||
|
learning_rate = lr_each_step[current_step:]
|
||||||
|
|
||||||
|
return learning_rate
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue