!12924 add ssd vgg backbone support
From: @caojian05 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @wuxuejianpull/12924/MERGE
commit
06a3f22834
@ -0,0 +1,84 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Config parameters for SSD models."""
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"model": "ssd_vgg16",
|
||||
"img_shape": [300, 300],
|
||||
"num_ssd_boxes": 7308,
|
||||
"match_threshold": 0.5,
|
||||
"nms_threshold": 0.6,
|
||||
"min_score": 0.1,
|
||||
"max_boxes": 100,
|
||||
"ssd_vgg_bn": False,
|
||||
|
||||
# learing rate settings
|
||||
"lr_init": 0.001,
|
||||
"lr_end_rate": 0.001,
|
||||
"warmup_epochs": 2,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1.5e-4,
|
||||
|
||||
# network
|
||||
"num_default": [3, 6, 6, 6, 6, 6],
|
||||
"extras_in_channels": [256, 512, 1024, 512, 256, 256],
|
||||
"extras_out_channels": [512, 1024, 512, 256, 256, 256],
|
||||
"extras_strides": [1, 1, 2, 2, 2, 2],
|
||||
"extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
|
||||
"feature_size": [38, 19, 10, 5, 3, 1],
|
||||
"min_scale": 0.2,
|
||||
"max_scale": 0.95,
|
||||
"aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
|
||||
"steps": (8, 16, 32, 64, 100, 300),
|
||||
"prior_scaling": (0.1, 0.2),
|
||||
"gamma": 2.0,
|
||||
"alpha": 0.75,
|
||||
|
||||
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||
"feature_extractor_base_param": "",
|
||||
"pretrain_vgg_bn": False,
|
||||
"checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],
|
||||
"mindrecord_dir": "/data/MindRecord_COCO",
|
||||
"coco_root": "/data/coco2017",
|
||||
"train_data_type": "train2017",
|
||||
"val_data_type": "val2017",
|
||||
"instances_set": "annotations/instances_{}.json",
|
||||
"classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||
'teddy bear', 'hair drier', 'toothbrush'),
|
||||
"num_classes": 81,
|
||||
# The annotation.json position of voc validation dataset.
|
||||
"voc_json": "annotations/voc_instances_val.json",
|
||||
# voc original dataset.
|
||||
"voc_root": "/data/voc_dataset",
|
||||
# if coco or voc used, `image_dir` and `anno_path` are useless.
|
||||
"image_dir": "",
|
||||
"anno_path": ""
|
||||
})
|
@ -0,0 +1,99 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""VGG16 backbone for SSD"""
|
||||
|
||||
from mindspore import nn
|
||||
from .config_ssd_vgg16 import config
|
||||
|
||||
pretrain_vgg_bn = config.pretrain_vgg_bn
|
||||
ssd_vgg_bn = config.ssd_vgg_bn
|
||||
|
||||
|
||||
def _get_key_mapper():
|
||||
vgg_key_num = [1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]
|
||||
size = len(vgg_key_num)
|
||||
|
||||
pretrain_vgg_bn_false = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]
|
||||
pretrain_vgg_bn_true = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40]
|
||||
ssd_vgg_bn_false = [0, 2, 0, 2, 0, 2, 4, 0, 2, 4, 0, 2, 4]
|
||||
ssd_vgg_bn_true = [0, 3, 0, 3, 0, 3, 6, 0, 3, 6, 0, 3, 6]
|
||||
|
||||
pretrain_vgg_keys = pretrain_vgg_bn_true if pretrain_vgg_bn else pretrain_vgg_bn_false
|
||||
ssd_vgg_keys = ssd_vgg_bn_true if ssd_vgg_bn else ssd_vgg_bn_false
|
||||
|
||||
pretrain_vgg_keys = ['layers.' + str(pretrain_vgg_keys[i]) for i in range(size)]
|
||||
ssd_vgg_keys = ['b' + str(vgg_key_num[i]) + '.' + str(ssd_vgg_keys[i]) for i in range(size)]
|
||||
|
||||
return {pretrain_vgg_keys[i]: ssd_vgg_keys[i] for i in range(size)}
|
||||
|
||||
|
||||
ssd_vgg_key_mapper = _get_key_mapper()
|
||||
|
||||
|
||||
def _make_layer(channels):
|
||||
in_channels = channels[0]
|
||||
layers = []
|
||||
for out_channels in channels[1:]:
|
||||
layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3))
|
||||
if ssd_vgg_bn:
|
||||
layers.append(nn.BatchNorm2d(out_channels))
|
||||
layers.append(nn.ReLU())
|
||||
in_channels = out_channels
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
|
||||
class VGG16(nn.Cell):
|
||||
def __init__(self):
|
||||
super(VGG16, self).__init__()
|
||||
self.b1 = _make_layer([3, 64, 64])
|
||||
self.b2 = _make_layer([64, 128, 128])
|
||||
self.b3 = _make_layer([128, 256, 256, 256])
|
||||
self.b4 = _make_layer([256, 512, 512, 512])
|
||||
self.b5 = _make_layer([512, 512, 512, 512])
|
||||
|
||||
self.m1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||
self.m2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||
self.m3 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||
self.m4 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='SAME')
|
||||
self.m5 = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='SAME')
|
||||
|
||||
def construct(self, x):
|
||||
# block1
|
||||
x = self.b1(x)
|
||||
x = self.m1(x)
|
||||
|
||||
# block2
|
||||
x = self.b2(x)
|
||||
x = self.m2(x)
|
||||
|
||||
# block3
|
||||
x = self.b3(x)
|
||||
x = self.m3(x)
|
||||
|
||||
# block4
|
||||
x = self.b4(x)
|
||||
block4 = x
|
||||
x = self.m4(x)
|
||||
|
||||
# block5
|
||||
x = self.b5(x)
|
||||
x = self.m5(x)
|
||||
|
||||
return block4, x
|
||||
|
||||
|
||||
def vgg16():
|
||||
return VGG16()
|
Loading…
Reference in new issue