change some settings in SSD

pull/1387/head
zhaoting 5 years ago
parent 72fd41786c
commit 6000abcdf3

@ -7,6 +7,7 @@
* DeepFM: a factorization-machine based neural network for CTR prediction on the Criteo dataset.
* DeepLabV3: significantly improves over our previous DeepLab versions without DenseCRF post-processing and attains comparable performance with other state-of-the-art models on the PASCAL VOC 2012 semantic image segmentation benchmark.
* Faster-RCNN: towards real-time object detection with region proposal networks on the COCO 2017 dataset.
* SSD: a single-stage object detection method on the COCO 2017 dataset.
* GoogLeNet: a deep convolutional neural network architecture codenamed Inception V1 for classification and detection on the CIFAR-10 dataset.
* Wide&Deep: jointly trained wide linear models and deep neural networks for recommender systems on the Criteo dataset.
* Frontend and User Interface

@ -1,88 +0,0 @@
# SSD Example
## Description
SSD network based on MobileNetV2, with support for training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Dataset
We use coco2017 as the training dataset in this example by default; you can also use your own datasets.
1. If the coco dataset is used. **Select dataset to coco when running the script.**
Install Cython and pycocotools.
```
pip install Cython
pip install pycocotools
```
And change the COCO_ROOT and other settings you need in `config.py`. The directory structure is as follows:
```
└─coco2017
├── annotations # annotation jsons
├── train2017 # train dataset
└── val2017 # infer dataset
```
2. If your own dataset is used. **Select dataset to other when running the script.**
Organize the dataset information into a TXT file; each row in the file is as follows:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is one image annotation, with fields separated by spaces: the first column is the relative path of the image, and the remaining columns are boxes and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `IMAGE_DIR` (the dataset directory) with the relative path listed in `ANNO_PATH` (the TXT file path); `IMAGE_DIR` and `ANNO_PATH` are set in `config.py`.
## Running the example
### Training
To train the model, run `train.py`. If `MINDRECORD_DIR` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files from `COCO_ROOT` (coco dataset) or `IMAGE_DIR` and `ANNO_PATH` (own dataset). **Note that if `MINDRECORD_DIR` isn't empty, it will use `MINDRECORD_DIR` instead of raw images.**
- Standalone mode
```
python train.py --dataset coco
```
You can run ```python train.py -h``` to get more information.
- Distribute mode
```
sh run_distribute_train.sh 8 150 coco /data/hccl.json
```
The input parameters are the device number, epoch size, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**
You will get the loss value of each step as follows:
```
epoch: 1 step: 455, loss is 5.8653416
epoch: 2 step: 455, loss is 5.4292373
epoch: 3 step: 455, loss is 5.458992
...
epoch: 148 step: 455, loss is 1.8340507
epoch: 149 step: 455, loss is 2.0876894
epoch: 150 step: 455, loss is 2.239692
```
### Evaluation
For evaluation, run `eval.py` with `ckpt_path`, the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
```
python eval.py --ckpt_path ssd.ckpt --dataset coco
```
You can run ```python eval.py -h``` to get more information.

@ -1,64 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for SSD models."""
class ConfigSSD:
"""
Config parameters for SSD.
Examples:
ConfigSSD().
"""
IMG_SHAPE = [300, 300]
NUM_SSD_BOXES = 1917
NEG_PRE_POSITIVE = 3
MATCH_THRESHOLD = 0.5
NUM_DEFAULT = [3, 6, 6, 6, 6, 6]
EXTRAS_IN_CHANNELS = [256, 576, 1280, 512, 256, 256]
EXTRAS_OUT_CHANNELS = [576, 1280, 512, 256, 256, 128]
EXTRAS_STRIDES = [1, 1, 2, 2, 2, 2]
EXTRAS_RATIO = [0.2, 0.2, 0.2, 0.25, 0.5, 0.25]
FEATURE_SIZE = [19, 10, 5, 3, 2, 1]
SCALES = [21, 45, 99, 153, 207, 261, 315]
ASPECT_RATIOS = [(1,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)]
STEPS = (16, 32, 64, 100, 150, 300)
PRIOR_SCALING = (0.1, 0.2)
# `MINDRECORD_DIR` and `COCO_ROOT` are better to use absolute path.
MINDRECORD_DIR = "MindRecord_COCO"
COCO_ROOT = "coco2017"
TRAIN_DATA_TYPE = "train2017"
VAL_DATA_TYPE = "val2017"
INSTANCES_SET = "annotations/instances_{}.json"
COCO_CLASSES = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush')
NUM_CLASSES = len(COCO_CLASSES)

File diff suppressed because it is too large.

@ -1,206 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""metrics utils"""
import numpy as np
from config import ConfigSSD
from dataset import ssd_bboxes_decode
def calc_iou(bbox_pred, bbox_ground):
"""Calculate iou of predicted bbox and ground truth."""
bbox_pred = np.expand_dims(bbox_pred, axis=0)
pred_w = bbox_pred[:, 2] - bbox_pred[:, 0]
pred_h = bbox_pred[:, 3] - bbox_pred[:, 1]
pred_area = pred_w * pred_h
gt_w = bbox_ground[:, 2] - bbox_ground[:, 0]
gt_h = bbox_ground[:, 3] - bbox_ground[:, 1]
gt_area = gt_w * gt_h
iw = np.minimum(bbox_pred[:, 2], bbox_ground[:, 2]) - np.maximum(bbox_pred[:, 0], bbox_ground[:, 0])
ih = np.minimum(bbox_pred[:, 3], bbox_ground[:, 3]) - np.maximum(bbox_pred[:, 1], bbox_ground[:, 1])
iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)
intersection_area = iw * ih
union_area = pred_area + gt_area - intersection_area
union_area = np.maximum(union_area, np.finfo(float).eps)
iou = intersection_area * 1. / union_area
return iou
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
x1 = all_boxes[:, 0]
y1 = all_boxes[:, 1]
x2 = all_boxes[:, 2]
y2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def calc_ap(recall, precision):
"""Calculate AP."""
correct_recall = np.concatenate(([0.], recall, [1.]))
correct_precision = np.concatenate(([0.], precision, [0.]))
for i in range(correct_recall.size - 1, 0, -1):
correct_precision[i - 1] = np.maximum(correct_precision[i - 1], correct_precision[i])
i = np.where(correct_recall[1:] != correct_recall[:-1])[0]
ap = np.sum((correct_recall[i + 1] - correct_recall[i]) * correct_precision[i + 1])
return ap
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
config = ConfigSSD()
num_classes = config.NUM_CLASSES
all_detections = [None for i in range(num_classes)]
all_pred_scores = [None for i in range(num_classes)]
all_annotations = [None for i in range(num_classes)]
average_precisions = {}
num = [0 for i in range(num_classes)]
accurate_num = [0 for i in range(num_classes)]
for sample in pred_data:
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
annotation = np.squeeze(annotation, axis=0)
pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
boxes_scores = boxes_scores[index]
pred_labels = pred_labels[index]
top_k = 50
for c in range(1, num_classes):
if len(pred_labels) >= 1:
class_box_scores = boxes_scores[pred_labels == c]
class_boxes = pred_boxes[pred_labels == c]
nms_index = apply_nms(class_boxes, class_box_scores, config.MATCH_THRESHOLD, top_k)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
cmask = class_box_scores > 0.5
class_boxes = class_boxes[cmask]
class_box_scores = class_box_scores[cmask]
all_detections[c] = class_boxes
all_pred_scores[c] = class_box_scores
for c in range(1, num_classes):
if len(annotation) >= 1:
all_annotations[c] = annotation[annotation[:, 4] == c, :4]
for c in range(1, num_classes):
false_positives = np.zeros((0,))
true_positives = np.zeros((0,))
scores = np.zeros((0,))
num_annotations = 0.0
annotations = all_annotations[c]
num_annotations += annotations.shape[0]
detections = all_detections[c]
pred_scores = all_pred_scores[c]
for index, detection in enumerate(detections):
scores = np.append(scores, pred_scores[index])
if len(annotations) >= 1:
IoUs = calc_iou(detection, annotations)
assigned_anno = np.argmax(IoUs)
max_overlap = IoUs[assigned_anno]
if max_overlap >= 0.5:
false_positives = np.append(false_positives, 0)
true_positives = np.append(true_positives, 1)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
else:
false_positives = np.append(false_positives, 1)
true_positives = np.append(true_positives, 0)
if num_annotations == 0:
if c not in average_precisions.keys():
average_precisions[c] = 0
continue
accurate_num[c] = 1
indices = np.argsort(-scores)
false_positives = false_positives[indices]
true_positives = true_positives[indices]
false_positives = np.cumsum(false_positives)
true_positives = np.cumsum(true_positives)
recall = true_positives * 1. / num_annotations
precision = true_positives * 1. / np.maximum(true_positives + false_positives, np.finfo(np.float64).eps)
average_precision = calc_ap(recall, precision)
if c not in average_precisions.keys():
average_precisions[c] = average_precision
else:
average_precisions[c] += average_precision
num[c] += 1
count = 0
for key in average_precisions:
if num[key] != 0:
count += (average_precisions[key] / num[key])
mAP = count * 1. / accurate_num.count(1)
return mAP

@ -0,0 +1,119 @@
# SSD Example
## Description
SSD network based on MobileNetV2, with support for training and evaluation.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Dataset
We use coco2017 as the training dataset in this example by default; you can also use your own datasets.
1. If the coco dataset is used. **Select dataset to coco when running the script.**
Install Cython and pycocotools.
```
pip install Cython
pip install pycocotools
```
Then change `coco_root` and the other settings you need in `config.py`. The directory structure is as follows:
```
.
└─cocodataset
├─annotations
├─instances_train2017.json
└─instances_val2017.json
├─val2017
└─train2017
```
2. If your own dataset is used. **Select dataset to other when running the script.**
Organize the dataset information into a TXT file; each row in the file is as follows:
```
train2017/0000001.jpg 0,259,401,459,7 35,28,324,201,2 0,30,59,80,2
```
Each row is one image annotation, with fields separated by spaces: the first column is the relative path of the image, and the remaining columns are boxes and class information in the format [xmin,ymin,xmax,ymax,class]. Images are read from the path obtained by joining `image_dir` (the dataset directory) with the relative path listed in `anno_path` (the TXT file path); `image_dir` and `anno_path` are set in `config.py` (a minimal parsing sketch follows below).
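The sketch below is illustrative only and not part of the provided scripts; it shows one way such a row could be parsed:
```
import os

def parse_anno_line(line, image_dir):
    """Parse one annotation row: '<relative/path.jpg> xmin,ymin,xmax,ymax,class ...'."""
    parts = line.strip().split(' ')
    # The first field is the image path relative to image_dir.
    image_path = os.path.join(image_dir, parts[0])
    # Each remaining field is one box: [xmin, ymin, xmax, ymax, class].
    boxes = [[int(v) for v in box.split(',')] for box in parts[1:]]
    return image_path, boxes
```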
## Running the example
### Training
To train the model, run `train.py`. If `mindrecord_dir` is empty, it will generate [mindrecord](https://www.mindspore.cn/tutorial/en/master/use/data_preparation/converting_datasets.html) files from `coco_root` (coco dataset) or `image_dir` and `anno_path` (own dataset). **Note that if `mindrecord_dir` isn't empty, it will use `mindrecord_dir` instead of raw images.**
- Standalone mode
```
python train.py --dataset coco
```
You can run ```python train.py -h``` to get more information.
- Distribute mode
```
sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json
```
The input parameters are the device number, epoch size, learning rate, dataset mode and [hccl json configuration file](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). **It is better to use absolute paths.**
You will get the loss value of each step as follows:
```
epoch: 1 step: 458, loss is 3.1681802
epoch time: 228752.4654865265, per step time: 499.4595316299705
epoch: 2 step: 458, loss is 2.8847265
epoch time: 38912.93382644653, per step time: 84.96273761232868
epoch: 3 step: 458, loss is 2.8398118
epoch time: 38769.184827804565, per step time: 84.64887516987896
...
epoch: 498 step: 458, loss is 0.70908034
epoch time: 38771.079778671265, per step time: 84.65301261718616
epoch: 499 step: 458, loss is 0.7974688
epoch time: 38787.413120269775, per step time: 84.68867493508685
epoch: 500 step: 458, loss is 0.5548882
epoch time: 39064.8467540741, per step time: 85.29442522723602
```
### Evaluation
For evaluation, run `eval.py` with `checkpoint_path`, the path of the [checkpoint](https://www.mindspore.cn/tutorial/en/master/use/saving_and_loading_model_parameters.html) file.
```
python eval.py --checkpoint_path ssd.ckpt --dataset coco
```
You can run ```python eval.py -h``` to get more information.
You will get the result as follows:
```
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.189
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.341
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.183
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.040
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.181
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.326
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.213
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.348
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.380
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.124
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.412
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.588
========================================
mAP: 0.18937438355383837
```

@ -14,49 +14,51 @@
# ============================================================================ # ============================================================================
"""Evaluation for SSD""" """Evaluation for SSD"""
import os import os
import argparse import argparse
import time import time
import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.model_zoo.ssd import SSD300, ssd_mobilenet_v2 from src.ssd import SSD300, ssd_mobilenet_v2
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from config import ConfigSSD from src.config import config
from util import metrics from src.coco_eval import metrics
def ssd_eval(dataset_path, ckpt_path): def ssd_eval(dataset_path, ckpt_path):
"""SSD evaluation.""" """SSD evaluation."""
batch_size = 1
ds = create_ssd_dataset(dataset_path, batch_size=1, repeat_num=1, is_training=False) ds = create_ssd_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
net = SSD300(ssd_mobilenet_v2(), ConfigSSD(), is_training=False) net = SSD300(ssd_mobilenet_v2(), config, is_training=False)
print("Load Checkpoint!") print("Load Checkpoint!")
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
net.init_parameters_data() net.init_parameters_data()
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net.set_train(False) net.set_train(False)
i = 1. i = batch_size
total = ds.get_dataset_size() total = ds.get_dataset_size() * batch_size
start = time.time() start = time.time()
pred_data = [] pred_data = []
print("\n========================================\n") print("\n========================================\n")
print("total images num: ", total) print("total images num: ", total)
print("Processing, please wait a moment.") print("Processing, please wait a moment.")
for data in ds.create_dict_iterator(): for data in ds.create_dict_iterator():
img_id = data['img_id']
img_np = data['image'] img_np = data['image']
image_shape = data['image_shape'] image_shape = data['image_shape']
annotation = data['annotation']
output = net(Tensor(img_np)) output = net(Tensor(img_np))
for batch_idx in range(img_np.shape[0]): for batch_idx in range(img_np.shape[0]):
pred_data.append({"boxes": output[0].asnumpy()[batch_idx], pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
"box_scores": output[1].asnumpy()[batch_idx], "box_scores": output[1].asnumpy()[batch_idx],
"annotation": annotation, "img_id": int(np.squeeze(img_id[batch_idx])),
"image_shape": image_shape}) "image_shape": image_shape[batch_idx]})
percent = round(i / total * 100, 2) percent = round(i / total * 100., 2)
print(f' {str(percent)} [{i}/{total}]', end='\r') print(f' {str(percent)} [{i}/{total}]', end='\r')
i += 1 i += batch_size
cost_time = int((time.time() - start) * 1000) cost_time = int((time.time() - start) * 1000)
print(f' 100% [{total}/{total}] cost {cost_time} ms') print(f' 100% [{total}/{total}] cost {cost_time} ms')
mAP = metrics(pred_data) mAP = metrics(pred_data)
@ -73,22 +75,21 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
config = ConfigSSD()
prefix = "ssd_eval.mindrecord" prefix = "ssd_eval.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file): if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir): if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir) os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco": if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT): if os.path.isdir(config.coco_root):
print("Create Mindrecord.") print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", False, prefix) data_to_mindrecord_byte_image("coco", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else: else:
print("COCO_ROOT not exits.") print("coco_root not exits.")
else: else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.") print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", False, prefix) data_to_mindrecord_byte_image("other", False, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) print("Create Mindrecord Done, at {}".format(mindrecord_dir))

@ -14,17 +14,16 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
echo "=================================================================================================================" echo "=============================================================================================================="
echo "Please run the scipt as: " echo "Please run the scipt as: "
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE" echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
echo "for example: sh run_distribute_train.sh 8 350 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)" echo "for example: sh run_distribute_train.sh 8 500 0.2 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)"
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
echo "=================================================================================================================" echo "================================================================================================================="
if [ $# != 4 ] && [ $# != 6 ] if [ $# != 5 ] && [ $# != 7 ]
then then
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [DATASET] \ echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" [MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
exit 1 exit 1
fi fi
@ -36,38 +35,39 @@ echo "After running the scipt, the network runs in the background. The log will
export RANK_SIZE=$1 export RANK_SIZE=$1
EPOCH_SIZE=$2 EPOCH_SIZE=$2
DATASET=$3 LR=$3
PRE_TRAINED=$5 DATASET=$4
PRE_TRAINED_EPOCH_SIZE=$6 PRE_TRAINED=$6
export MINDSPORE_HCCL_CONFIG_PATH=$4 PRE_TRAINED_EPOCH_SIZE=$7
export MINDSPORE_HCCL_CONFIG_PATH=$5
for((i=0;i<RANK_SIZE;i++)) for((i=0;i<RANK_SIZE;i++))
do do
export DEVICE_ID=$i export DEVICE_ID=$i
rm -rf LOG$i rm -rf LOG$i
mkdir ./LOG$i mkdir ./LOG$i
cp *.py ./LOG$i cp ../*.py ./LOG$i
cp -r ../src ./LOG$i
cd ./LOG$i || exit cd ./LOG$i || exit
export RANK_ID=$i export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID" echo "start training for rank $i, device $DEVICE_ID"
env > env.log env > env.log
if [ $# == 4 ] if [ $# == 5 ]
then then
python ../train.py \ python train.py \
--distribute=1 \ --distribute=1 \
--lr=0.4 \ --lr=$LR \
--dataset=$DATASET \ --dataset=$DATASET \
--device_num=$RANK_SIZE \ --device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 & --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
fi fi
if [ $# == 6 ] if [ $# == 7 ]
then then
python ../train.py \ python train.py \
--distribute=1 \ --distribute=1 \
--lr=0.4 \ --lr=$LR \
--dataset=$DATASET \ --dataset=$DATASET \
--device_num=$RANK_SIZE \ --device_num=$RANK_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \

@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bbox utils"""
import math
import itertools as it
import numpy as np
from .config import config
class GeneratDefaultBoxes():
"""
    Generate default boxes for SSD, following the order of (W, H, anchor_sizes).
    `self.default_boxes` has a shape of [anchor_sizes, H, W, 4]; the last dimension is [y, x, h, w].
    `self.default_boxes_ltrb` has the same shape as `self.default_boxes`; the last dimension is [y1, x1, y2, x2].
"""
def __init__(self):
fk = config.img_shape[0] / np.array(config.steps)
scale_rate = (config.max_scale - config.min_scale) / (len(config.num_default) - 1)
scales = [config.min_scale + scale_rate * i for i in range(len(config.num_default))] + [1.0]
self.default_boxes = []
for idex, feature_size in enumerate(config.feature_size):
sk1 = scales[idex]
sk2 = scales[idex + 1]
sk3 = math.sqrt(sk1 * sk2)
if idex == 0:
w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2)
all_sizes = [(0.1, 0.1), (w, h), (h, w)]
else:
all_sizes = [(sk1, sk1)]
for aspect_ratio in config.aspect_ratios[idex]:
w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio)
all_sizes.append((w, h))
all_sizes.append((h, w))
all_sizes.append((sk3, sk3))
assert len(all_sizes) == config.num_default[idex]
for i, j in it.product(range(feature_size), repeat=2):
for w, h in all_sizes:
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
self.default_boxes.append([cy, cx, h, w])
def to_ltrb(cy, cx, h, w):
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
# For IoU calculation
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
self.default_boxes = np.array(self.default_boxes, dtype='float32')
default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
default_boxes = GeneratDefaultBoxes().default_boxes
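# With the default config this is sum(num_default[i] * feature_size[i]^2) = 3*19^2 + 6*(10^2 + 5^2 + 3^2 + 2^2 + 1^2) = 1917 boxes, matching config.num_ssd_boxes.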
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
vol_anchors = (x2 - x1) * (y2 - y1)
matching_threshold = config.match_thershold
def ssd_bboxes_encode(boxes):
"""
Labels anchors with ground truth inputs.
Args:
        boxes: ground truth with shape [N, 5]; each row stores [ymin, xmin, ymax, xmax, cls].
Returns:
gt_loc: location ground truth with shape [num_anchors, 4].
gt_label: class ground truth with shape [num_anchors, 1].
num_matched_boxes: number of positives in an image.
"""
def jaccard_with_anchors(bbox):
"""Compute jaccard score a box and the anchors."""
# Intersection bbox and volume.
ymin = np.maximum(y1, bbox[0])
xmin = np.maximum(x1, bbox[1])
ymax = np.minimum(y2, bbox[2])
xmax = np.minimum(x2, bbox[3])
w = np.maximum(xmax - xmin, 0.)
h = np.maximum(ymax - ymin, 0.)
# Volumes.
inter_vol = h * w
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
jaccard = inter_vol / union_vol
return np.squeeze(jaccard)
pre_scores = np.zeros((config.num_ssd_boxes), dtype=np.float32)
t_boxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
t_label = np.zeros((config.num_ssd_boxes), dtype=np.int64)
for bbox in boxes:
label = int(bbox[4])
scores = jaccard_with_anchors(bbox)
idx = np.argmax(scores)
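        # Force the anchor with the highest IoU to match this ground-truth box by giving it a score above any possible IoU.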
scores[idx] = 2.0
mask = (scores > matching_threshold)
mask = mask & (scores > pre_scores)
pre_scores = np.maximum(pre_scores, scores * mask)
t_label = mask * label + (1 - mask) * t_label
for i in range(4):
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
index = np.nonzero(t_label)
    # Transform from corner format [ymin, xmin, ymax, xmax] to center format [cy, cx, h, w].
bboxes = np.zeros((config.num_ssd_boxes, 4), dtype=np.float32)
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
# Encode features.
bboxes_t = bboxes[index]
default_boxes_t = default_boxes[index]
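    # Standard SSD encoding: center offsets are normalized by the default box size and prior_scaling[0]; sizes use the log of the size ratio divided by prior_scaling[1].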
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
bboxes_t[:, 2:4] = np.log(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4]) / config.prior_scaling[1]
bboxes[index] = bboxes_t
num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match
def ssd_bboxes_decode(boxes):
"""Decode predict boxes to [y, x, h, w]"""
boxes_t = boxes.copy()
default_boxes_t = default_boxes.copy()
boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
return np.clip(bboxes, 0, 1)
def intersect(box_a, box_b):
"""Compute the intersect of two sets of boxes."""
max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
min_yx = np.maximum(box_a[:, :2], box_b[:2])
inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
return inter[:, 0] * inter[:, 1]
def jaccard_numpy(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes."""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) *
(box_a[:, 3] - box_a[:, 1]))
area_b = ((box_b[2] - box_b[0]) *
(box_b[3] - box_b[1]))
union = area_a + area_b - inter
return inter / union

@ -0,0 +1,127 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Coco metrics utils"""
import os
import json
import numpy as np
from .config import config
from .box_utils import ssd_bboxes_decode
def apply_nms(all_boxes, all_scores, thres, max_boxes):
"""Apply NMS to bboxes."""
y1 = all_boxes[:, 0]
x1 = all_boxes[:, 1]
y2 = all_boxes[:, 2]
x2 = all_boxes[:, 3]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = all_scores.argsort()[::-1]
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
if len(keep) >= max_boxes:
break
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas[i] + areas[order[1:]] - inter)
inds = np.where(ovr <= thres)[0]
order = order[inds + 1]
return keep
def metrics(pred_data):
"""Calculate mAP of predicted bboxes."""
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
num_classes = config.num_classes
coco_root = config.coco_root
data_type = config.val_data_type
    # Classes to train or test.
val_cls = config.coco_classes
val_cls_dict = {}
for i, cls in enumerate(val_cls):
val_cls_dict[i] = cls
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
coco_gt = COCO(anno_json)
classs_dict = {}
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
for cat in cat_ids:
classs_dict[cat["name"]] = cat["id"]
predictions = []
img_ids = []
for sample in pred_data:
pred_boxes = sample['boxes']
box_scores = sample['box_scores']
img_id = sample['img_id']
h, w = sample['image_shape']
pred_boxes = ssd_bboxes_decode(pred_boxes)
final_boxes = []
final_label = []
final_score = []
img_ids.append(img_id)
for c in range(1, num_classes):
class_box_scores = box_scores[:, c]
score_mask = class_box_scores > config.min_score
class_box_scores = class_box_scores[score_mask]
class_boxes = pred_boxes[score_mask] * [h, w, h, w]
if score_mask.any():
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
class_boxes = class_boxes[nms_index]
class_box_scores = class_box_scores[nms_index]
final_boxes += class_boxes.tolist()
final_score += class_box_scores.tolist()
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)
for loc, label, score in zip(final_boxes, final_label, final_score):
res = {}
res['image_id'] = img_id
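            # Convert [ymin, xmin, ymax, xmax] (in pixels) to the COCO format [x, y, width, height].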
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
res['score'] = score
res['category_id'] = label
predictions.append(res)
with open('predictions.json', 'w') as f:
json.dump(predictions, f)
coco_dt = coco_gt.loadRes('predictions.json')
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
E.params.imgIds = img_ids
E.evaluate()
E.accumulate()
E.summarize()
return E.stats[0]

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""Config parameters for SSD models."""
from easydict import EasyDict as ed
config = ed({
"img_shape": [300, 300],
"num_ssd_boxes": 1917,
"neg_pre_positive": 3,
"match_thershold": 0.5,
"nms_thershold": 0.6,
"min_score": 0.1,
"max_boxes": 100,
    # learning rate settings
"global_step": 0,
"lr_init": 0.001,
"lr_end_rate": 0.001,
"warmup_epochs": 2,
"momentum": 0.9,
"weight_decay": 1.5e-4,
# network
"num_default": [3, 6, 6, 6, 6, 6],
"extras_in_channels": [256, 576, 1280, 512, 256, 256],
"extras_out_channels": [576, 1280, 512, 256, 256, 128],
"extras_srides": [1, 1, 2, 2, 2, 2],
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
"feature_size": [19, 10, 5, 3, 2, 1],
"min_scale": 0.2,
"max_scale": 0.95,
"aspect_ratios": [(2,), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
"steps": (16, 32, 64, 100, 150, 300),
"prior_scaling": (0.1, 0.2),
"gamma": 2.0,
"alpha": 0.75,
    # It is better to use absolute paths for `mindrecord_dir` and `coco_root`.
"mindrecord_dir": "/data/MindRecord_COCO",
"coco_root": "/data/coco2017",
"train_data_type": "train2017",
"val_data_type": "val2017",
"instances_set": "annotations/instances_{}.json",
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush'),
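    # 80 COCO object categories plus one background class.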
"num_classes": 81,
    # If coco is used, `image_dir` and `anno_path` are not needed.
"image_dir": "",
"anno_path": "",
})

File diff suppressed because it is too large.

@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameters utils"""
from mindspore import Tensor
from mindspore.common.initializer import initializer, TruncatedNormal
def init_net_param(network, initialize_mode='TruncatedNormal'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
if initialize_mode == 'TruncatedNormal':
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape(), p.data.dtype()))
else:
                p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
def load_backbone_params(network, param_dict):
"""Init the parameters from pre-train model, default is mobilenetv2."""
    for _, param in network.parameters_and_names():
param_name = param.name.replace('network.backbone.', '')
name_split = param_name.split('.')
if 'features_1' in param_name:
param_name = param_name.replace('features_1', 'features')
if 'features_2' in param_name:
param_name = '.'.join(['features', str(int(name_split[1]) + 14)] + name_split[2:])
if param_name in param_dict:
param.set_parameter_data(param_dict[param_name].data)

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate schedule"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(float): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
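    # Linear warmup for the first warmup_steps, then cosine decay from lr_max down to lr_end.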
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate

File diff suppressed because it is too large.

@ -13,83 +13,38 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""train SSD and get checkpoint files.""" """Train SSD and get checkpoint files."""
import os import os
import math
import argparse import argparse
import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context, Tensor from mindspore import context, Tensor
from mindspore.communication.management import init from mindspore.communication.management import init
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from mindspore.train import Model, ParallelMode from mindspore.train import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common.initializer import initializer from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2
from src.config import config
from mindspore.model_zoo.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2 from src.dataset import create_ssd_dataset, data_to_mindrecord_byte_image
from config import ConfigSSD from src.lr_schedule import get_lr
from dataset import create_ssd_dataset, data_to_mindrecord_byte_image from src.init_params import init_net_param
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + (lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
def init_net_param(network, initialize_mode='XavierUniform'):
"""Init the parameters in net."""
params = network.trainable_params()
for p in params:
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
p.set_parameter_data(initializer(initialize_mode, p.data.shape(), p.data.dtype()))
def main(): def main():
parser = argparse.ArgumentParser(description="SSD training") parser = argparse.ArgumentParser(description="SSD training")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.") "Mindrecord, default is False.")
parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is false.") parser.add_argument("--distribute", type=bool, default=False, help="Run distribute, default is False.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.25, help="Learning rate, default is 0.25.") parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.") parser.add_argument("--epoch_size", type=int, default=250, help="Epoch size, default is 250.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.") parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 5.")
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
args_opt = parser.parse_args() args_opt = parser.parse_args()
@ -111,27 +66,26 @@ def main():
# It will generate mindrecord file in args_opt.mindrecord_dir, # It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is ssd.mindrecord0, 1, ... file_num. # and the file name is ssd.mindrecord0, 1, ... file_num.
config = ConfigSSD()
prefix = "ssd.mindrecord" prefix = "ssd.mindrecord"
mindrecord_dir = config.MINDRECORD_DIR mindrecord_dir = config.mindrecord_dir
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
if not os.path.exists(mindrecord_file): if not os.path.exists(mindrecord_file):
if not os.path.isdir(mindrecord_dir): if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir) os.makedirs(mindrecord_dir)
if args_opt.dataset == "coco": if args_opt.dataset == "coco":
if os.path.isdir(config.COCO_ROOT): if os.path.isdir(config.coco_root):
print("Create Mindrecord.") print("Create Mindrecord.")
data_to_mindrecord_byte_image("coco", True, prefix) data_to_mindrecord_byte_image("coco", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else: else:
print("COCO_ROOT not exits.") print("coco_root not exits.")
else: else:
if os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH): if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
print("Create Mindrecord.") print("Create Mindrecord.")
data_to_mindrecord_byte_image("other", True, prefix) data_to_mindrecord_byte_image("other", True, prefix)
print("Create Mindrecord Done, at {}".format(mindrecord_dir)) print("Create Mindrecord Done, at {}".format(mindrecord_dir))
else: else:
print("IMAGE_DIR or ANNO_PATH not exits.") print("image_dir or anno_path not exits.")
if not args_opt.only_create_dataset: if not args_opt.only_create_dataset:
loss_scale = float(args_opt.loss_scale) loss_scale = float(args_opt.loss_scale)
@ -143,7 +97,8 @@ def main():
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("Create dataset done!") print("Create dataset done!")
ssd = SSD300(backbone=ssd_mobilenet_v2(), config=config) backbone = ssd_mobilenet_v2()
ssd = SSD300(backbone=backbone, config=config)
net = SSDWithLossCell(ssd, config) net = SSDWithLossCell(ssd, config)
init_net_param(net) init_net_param(net)
@ -157,12 +112,13 @@ def main():
param_dict = load_checkpoint(args_opt.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size, lr = Tensor(get_lr(global_step=config.global_step,
lr_init=0, lr_end=0, lr_max=args_opt.lr, lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
warmup_epochs=max(350 // 20, 1), warmup_epochs=config.warmup_epochs,
total_epochs=350, total_epochs=args_opt.epoch_size,
steps_per_epoch=dataset_size)) steps_per_epoch=dataset_size))
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale) opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
config.momentum, config.weight_decay, loss_scale)
net = TrainingWrapper(net, opt, loss_scale) net = TrainingWrapper(net, opt, loss_scale)
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]