parent e8c8b40e5a
commit aeb6d2a59f
@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Anchor Generator"""

import numpy as np


class GridAnchorGenerator:
    """
    Anchor Generator
    """
    def __init__(self, image_shape, scale, scales_per_octave, aspect_ratios):
        super(GridAnchorGenerator, self).__init__()
        self.scale = scale
        self.scales_per_octave = scales_per_octave
        self.aspect_ratios = aspect_ratios
        self.image_shape = image_shape

    def generate(self, step):
        scales = np.array([2**(float(scale) / self.scales_per_octave)
                           for scale in range(self.scales_per_octave)]).astype(np.float32)
        aspects = np.array(list(self.aspect_ratios)).astype(np.float32)

        scales_grid, aspect_ratios_grid = np.meshgrid(scales, aspects)
        scales_grid = scales_grid.reshape([-1])
        aspect_ratios_grid = aspect_ratios_grid.reshape([-1])

        feature_size = [self.image_shape[0] / step, self.image_shape[1] / step]
        grid_height, grid_width = feature_size

        base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32)
        anchor_offset = step / 2.0

        ratio_sqrt = np.sqrt(aspect_ratios_grid)
        heights = scales_grid / ratio_sqrt * base_size[0]
        widths = scales_grid * ratio_sqrt * base_size[1]

        y_centers = np.arange(grid_height).astype(np.float32)
        y_centers = y_centers * step + anchor_offset
        x_centers = np.arange(grid_width).astype(np.float32)
        x_centers = x_centers * step + anchor_offset
        x_centers, y_centers = np.meshgrid(x_centers, y_centers)

        x_centers_shape = x_centers.shape
        y_centers_shape = y_centers.shape

        widths_grid, x_centers_grid = np.meshgrid(widths, x_centers.reshape([-1]))
        heights_grid, y_centers_grid = np.meshgrid(heights, y_centers.reshape([-1]))

        x_centers_grid = x_centers_grid.reshape(*x_centers_shape, -1)
        y_centers_grid = y_centers_grid.reshape(*y_centers_shape, -1)
        widths_grid = widths_grid.reshape(-1, *x_centers_shape)
        heights_grid = heights_grid.reshape(-1, *y_centers_shape)

        bbox_centers = np.stack([y_centers_grid, x_centers_grid], axis=3)
        bbox_sizes = np.stack([heights_grid, widths_grid], axis=3)
        bbox_centers = bbox_centers.reshape([-1, 2])
        bbox_sizes = bbox_sizes.reshape([-1, 2])
        bbox_corners = np.concatenate([bbox_centers - 0.5 * bbox_sizes, bbox_centers + 0.5 * bbox_sizes], axis=1)
        self.bbox_corners = bbox_corners / np.array([*self.image_shape, *self.image_shape]).astype(np.float32)
        self.bbox_centers = np.concatenate([bbox_centers, bbox_sizes], axis=1)
        self.bbox_centers = self.bbox_centers / np.array([*self.image_shape, *self.image_shape]).astype(np.float32)

        print(self.bbox_centers.shape)
        return self.bbox_centers, self.bbox_corners

    def generate_multi_levels(self, steps):
        bbox_centers_list = []
        bbox_corners_list = []
        for step in steps:
            bbox_centers, bbox_corners = self.generate(step)
            bbox_centers_list.append(bbox_centers)
            bbox_corners_list.append(bbox_corners)

        self.bbox_centers = np.concatenate(bbox_centers_list, axis=0)
        self.bbox_corners = np.concatenate(bbox_corners_list, axis=0)
        return self.bbox_centers, self.bbox_corners
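A brief usage sketch, not part of the diff: the constructor arguments shown here (scale=4, scales_per_octave=2, aspect_ratios=(1.0, 2.0, 0.5)) are illustrative assumptions; only image_shape and steps correspond to the ssd_mobilenet_v1_fpn config below.

# Hypothetical example: multi-level anchors for a 640x640 input at FPN strides 8..128.
generator = GridAnchorGenerator(image_shape=(640, 640), scale=4,
                                scales_per_octave=2, aspect_ratios=(1.0, 2.0, 0.5))
centers, corners = generator.generate_multi_levels(steps=(8, 16, 32, 64, 128))
# Each level contributes (640/step)**2 * scales_per_octave * len(aspect_ratios) boxes,
# which sums to 51150, the `num_ssd_boxes` value used by the FPN config.
assert centers.shape == (51150, 4) and corners.shape == (51150, 4)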
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Config parameters for SSD models."""

from easydict import EasyDict as ed

config = ed({
    "model": "ssd300",
    "img_shape": [300, 300],
    "num_ssd_boxes": 1917,
    "neg_pre_positive": 3,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "lr_init": 0.001,
    "lr_end_rate": 0.001,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [3, 6, 6, 6, 6, 6],
    "extras_in_channels": [256, 576, 1280, 512, 256, 256],
    "extras_out_channels": [576, 1280, 512, 256, 256, 128],
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [19, 10, 5, 3, 2, 1],
    "min_scale": 0.2,
    "max_scale": 0.95,
    "aspect_ratios": [(), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (16, 32, 64, 100, 150, 300),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,

    # It is recommended to set `mindrecord_dir` and `coco_root` to absolute paths.
    "feature_extractor_base_param": "",
    "mindrecord_dir": "/data/MindRecord_COCO",
    "coco_root": "/data/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,
    # Path to the annotation json of the VOC validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # Root directory of the original VOC dataset.
    "voc_root": "/data/voc_dataset",
    # If COCO or VOC is used, `image_dir` and `anno_path` are ignored.
    "image_dir": "",
    "anno_path": "",
    "export_format": "MINDIR",
    "export_file": "ssd.mindir"
})
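A quick consistency check (illustrative, not part of the config file): `num_ssd_boxes` matches the total number of default boxes implied by `feature_size` and `num_default`.

# 19*19*3 + 10*10*6 + 5*5*6 + 3*3*6 + 2*2*6 + 1*1*6 = 1917
feature_size = [19, 10, 5, 3, 2, 1]
num_default = [3, 6, 6, 6, 6, 6]
assert sum(f * f * n for f, n in zip(feature_size, num_default)) == 1917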
@@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Config parameters for SSD models."""

from easydict import EasyDict as ed

config = ed({
    "model": "ssd_mobilenet_v1_fpn",
    "img_shape": [640, 640],
    "num_ssd_boxes": 51150,
    "neg_pre_positive": 3,
    "match_threshold": 0.5,
    "nms_threshold": 0.6,
    "min_score": 0.1,
    "max_boxes": 100,

    # learning rate settings
    "global_step": 0,
    "lr_init": 0.01333,
    "lr_end_rate": 0.0,
    "warmup_epochs": 2,
    "momentum": 0.9,
    "weight_decay": 1.5e-4,

    # network
    "num_default": [6, 6, 6, 6, 6],
    "extras_in_channels": [256, 512, 1024, 256, 256],
    "extras_out_channels": [256, 256, 256, 256, 256],
    "extras_strides": [1, 1, 2, 2, 2, 2],
    "extras_ratio": [0.2, 0.2, 0.2, 0.25, 0.5, 0.25],
    "feature_size": [80, 40, 20, 10, 5],
    "min_scale": 0.2,
    "max_scale": 0.95,
    "aspect_ratios": [(2, 3), (2, 3), (2, 3), (2, 3), (2, 3), (2, 3)],
    "steps": (8, 16, 32, 64, 128),
    "prior_scaling": (0.1, 0.2),
    "gamma": 2.0,
    "alpha": 0.75,
    "num_addition_layers": 4,
    "use_anchor_generator": True,
    "use_global_norm": True,

    # It is recommended to set `mindrecord_dir` and `coco_root` to absolute paths.
    "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt",
    "mindrecord_dir": "/data/MindRecord_COCO",
    "coco_root": "/data/coco2017",
    "train_data_type": "train2017",
    "val_data_type": "val2017",
    "instances_set": "annotations/instances_{}.json",
    "classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush'),
    "num_classes": 81,
    # Path to the annotation json of the VOC validation dataset.
    "voc_json": "annotations/voc_instances_val.json",
    # Root directory of the original VOC dataset.
    "voc_root": "/data/voc_dataset",
    # If COCO or VOC is used, `image_dir` and `anno_path` are ignored.
    "image_dir": "",
    "anno_path": "",
    "export_format": "MINDIR",
    "export_file": "ssd.mindir"
})
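The same check for the FPN variant (illustrative only): five feature maps at strides 8 through 128 of a 640x640 input, each with 6 default boxes per cell, give the 51150 boxes declared in `num_ssd_boxes`.

feature_size = [80, 40, 20, 10, 5]
steps = (8, 16, 32, 64, 128)
assert [640 // s for s in steps] == feature_size
assert sum(f * f * 6 for f in feature_size) == 51150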
@@ -0,0 +1,192 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F


def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
    output = []
    output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same",
                            group=1 if not depthwise else in_channel))
    output.append(nn.BatchNorm2d(out_channel))
    if activation:
        output.append(nn.get_activation(activation))
    return nn.SequentialCell(output)


class MobileNetV1(nn.Cell):
    """
    MobileNet V1 backbone
    """
    def __init__(self, class_num=1001, features_only=False):
        super(MobileNetV1, self).__init__()
        self.features_only = features_only
        cnn = [
            conv_bn_relu(3, 32, 3, 2, False),       # Conv0

            conv_bn_relu(32, 32, 3, 1, True),       # Conv1_depthwise
            conv_bn_relu(32, 64, 1, 1, False),      # Conv1_pointwise
            conv_bn_relu(64, 64, 3, 2, True),       # Conv2_depthwise
            conv_bn_relu(64, 128, 1, 1, False),     # Conv2_pointwise

            conv_bn_relu(128, 128, 3, 1, True),     # Conv3_depthwise
            conv_bn_relu(128, 128, 1, 1, False),    # Conv3_pointwise
            conv_bn_relu(128, 128, 3, 2, True),     # Conv4_depthwise
            conv_bn_relu(128, 256, 1, 1, False),    # Conv4_pointwise

            conv_bn_relu(256, 256, 3, 1, True),     # Conv5_depthwise
            conv_bn_relu(256, 256, 1, 1, False),    # Conv5_pointwise
            conv_bn_relu(256, 256, 3, 2, True),     # Conv6_depthwise
            conv_bn_relu(256, 512, 1, 1, False),    # Conv6_pointwise

            conv_bn_relu(512, 512, 3, 1, True),     # Conv7_depthwise
            conv_bn_relu(512, 512, 1, 1, False),    # Conv7_pointwise
            conv_bn_relu(512, 512, 3, 1, True),     # Conv8_depthwise
            conv_bn_relu(512, 512, 1, 1, False),    # Conv8_pointwise
            conv_bn_relu(512, 512, 3, 1, True),     # Conv9_depthwise
            conv_bn_relu(512, 512, 1, 1, False),    # Conv9_pointwise
            conv_bn_relu(512, 512, 3, 1, True),     # Conv10_depthwise
            conv_bn_relu(512, 512, 1, 1, False),    # Conv10_pointwise
            conv_bn_relu(512, 512, 3, 1, True),     # Conv11_depthwise
            conv_bn_relu(512, 512, 1, 1, False),    # Conv11_pointwise

            conv_bn_relu(512, 512, 3, 2, True),     # Conv12_depthwise
            conv_bn_relu(512, 1024, 1, 1, False),   # Conv12_pointwise
            conv_bn_relu(1024, 1024, 3, 1, True),   # Conv13_depthwise
            conv_bn_relu(1024, 1024, 1, 1, False),  # Conv13_pointwise
        ]

        if self.features_only:
            self.network = nn.CellList(cnn)
        else:
            self.network = nn.SequentialCell(cnn)
            self.fc = nn.Dense(1024, class_num)

    def construct(self, x):
        output = x
        if self.features_only:
            features = ()
            for block in self.network:
                output = block(output)
                features = features + (output,)
            return features
        output = self.network(x)
        output = P.ReduceMean()(output, (2, 3))
        output = self.fc(output)
        return output


class FpnTopDown(nn.Cell):
    """
    Fpn to extract features
    """
    def __init__(self, in_channel_list, out_channels):
        super(FpnTopDown, self).__init__()
        self.lateral_convs_list_ = []
        self.fpn_convs_ = []
        for channel in in_channel_list:
            l_conv = nn.Conv2d(channel, out_channels, kernel_size=1, stride=1,
                               has_bias=True, padding=0, pad_mode='same')
            fpn_conv = conv_bn_relu(out_channels, out_channels, kernel_size=3, stride=1, depthwise=False)
            self.lateral_convs_list_.append(l_conv)
            self.fpn_convs_.append(fpn_conv)
        self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_)
        self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_)
        self.num_layers = len(in_channel_list)

    def construct(self, inputs):
        image_features = ()
        for i, feature in enumerate(inputs):
            image_features = image_features + (self.lateral_convs_list[i](feature),)

        features = (image_features[-1],)
        for i in range(len(inputs) - 1):
            top = len(inputs) - i - 1
            down = top - 1
            size = F.shape(inputs[down])
            top_down = P.ResizeBilinear((size[2], size[3]))(features[-1])
            top_down = top_down + image_features[down]
            features = features + (top_down,)

        extract_features = ()
        num_features = len(features)
        for i in range(num_features):
            extract_features = extract_features + (self.fpn_convs_list[i](features[num_features - i - 1]),)

        return extract_features


class BottomUp(nn.Cell):
    """
    Bottom Up feature extractor
    """
    def __init__(self, levels, channels, kernel_size, stride):
        super(BottomUp, self).__init__()
        self.levels = levels
        bottom_up_cells = [
            conv_bn_relu(channels, channels, kernel_size, stride, False) for x in range(self.levels)
        ]
        self.blocks = nn.CellList(bottom_up_cells)

    def construct(self, features):
        for block in self.blocks:
            features = features + (block(features[-1]),)
        return features


class FeatureSelector(nn.Cell):
    """
    Select specific layers from an entire feature list
    """
    def __init__(self, feature_idxes):
        super(FeatureSelector, self).__init__()
        self.feature_idxes = feature_idxes

    def construct(self, feature_list):
        selected = ()
        for i in self.feature_idxes:
            selected = selected + (feature_list[i],)
        return selected


class MobileNetV1Fpn(nn.Cell):
    """
    MobileNetV1 with FPN as SSD backbone.
    """
    def __init__(self, config):
        super(MobileNetV1Fpn, self).__init__()
        self.mobilenet_v1 = MobileNetV1(features_only=True)

        self.selector = FeatureSelector([10, 22, 26])

        self.layer_indexs = [10, 22, 26]
        self.fpn = FpnTopDown([256, 512, 1024], 256)
        self.bottom_up = BottomUp(2, 256, 3, 2)

    def construct(self, x):
        features = self.mobilenet_v1(x)
        features = self.selector(features)
        features = self.fpn(features)
        features = self.bottom_up(features)
        return features


def mobilenet_v1_fpn(config):
    return MobileNetV1Fpn(config)


def mobilenet_v1(class_num=1001):
    return MobileNetV1(class_num)
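A minimal shape-check sketch for the backbone (not part of the diff; assumes MindSpore is installed and `config` is the ssd_mobilenet_v1_fpn config defined above, which the constructor currently only passes through).

import numpy as np
import mindspore as ms

net = mobilenet_v1_fpn(config)
x = ms.Tensor(np.zeros((1, 3, 640, 640), dtype=np.float32))
outputs = net(x)
# Expected: five 256-channel maps at strides 8/16/32/64/128,
# i.e. spatial sizes 80, 40, 20, 10 and 5 for a 640x640 input.
for feature in outputs:
    print(feature.shape)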