From 0f5fcfee39b8aedf03320f42414968d68db364ac Mon Sep 17 00:00:00 2001
From: chengxianbin
Date: Sun, 14 Jun 2020 10:20:01 +0800
Subject: [PATCH] add ci test case for yolov3

---
 tests/st/model_zoo_tests/yolov3/src/config.py |  49 ++
 .../st/model_zoo_tests/yolov3/src/dataset.py  | 318 ++++
 tests/st/model_zoo_tests/yolov3/src/yolov3.py | 748 ++++++++++++++++++
 .../st/model_zoo_tests/yolov3/test_yolov3.py  | 157 ++++
 4 files changed, 1272 insertions(+)
 create mode 100644 tests/st/model_zoo_tests/yolov3/src/config.py
 create mode 100644 tests/st/model_zoo_tests/yolov3/src/dataset.py
 create mode 100644 tests/st/model_zoo_tests/yolov3/src/yolov3.py
 create mode 100644 tests/st/model_zoo_tests/yolov3/test_yolov3.py

diff --git a/tests/st/model_zoo_tests/yolov3/src/config.py b/tests/st/model_zoo_tests/yolov3/src/config.py
new file mode 100644
index 0000000000..37bdcb944b
--- /dev/null
+++ b/tests/st/model_zoo_tests/yolov3/src/config.py
@@ -0,0 +1,49 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Config parameters for YOLOv3 models."""
+
+
+class ConfigYOLOV3ResNet18:
+    """
+    Config parameters for YOLOv3.
+
+    Examples:
+        ConfigYOLOV3ResNet18.
+    """
+    img_shape = [352, 640]
+    feature_shape = [32, 3, 352, 640]
+    num_classes = 2
+    nms_max_num = 50
+
+    backbone_input_shape = [64, 64, 128, 256]
+    backbone_shape = [64, 128, 256, 512]
+    backbone_layers = [2, 2, 2, 2]
+    backbone_stride = [1, 2, 2, 2]
+
+    ignore_threshold = 0.5
+    obj_threshold = 0.3
+    nms_threshold = 0.4
+
+    anchor_scales = [(10, 13),
+                     (16, 30),
+                     (33, 23),
+                     (30, 61),
+                     (62, 45),
+                     (59, 119),
+                     (116, 90),
+                     (156, 198),
+                     (163, 326)]
+    out_channel = int(len(anchor_scales) / 3 * (num_classes + 5))
diff --git a/tests/st/model_zoo_tests/yolov3/src/dataset.py b/tests/st/model_zoo_tests/yolov3/src/dataset.py
new file mode 100644
index 0000000000..e13802566b
--- /dev/null
+++ b/tests/st/model_zoo_tests/yolov3/src/dataset.py
@@ -0,0 +1,318 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +"""YOLOv3 dataset""" +from __future__ import division + +import os +import numpy as np +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb +from PIL import Image +import mindspore.dataset as de +from mindspore.mindrecord import FileWriter +import mindspore.dataset.transforms.vision.c_transforms as C +from src.config import ConfigYOLOV3ResNet18 + +iter_cnt = 0 +_NUM_BOXES = 50 +np.random.seed(1) +de.config.set_seed(1) + +def preprocess_fn(image, box, is_training): + """Preprocess function for dataset.""" + config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326] + anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2) + do_hsv = False + max_boxes = 20 + num_classes = ConfigYOLOV3ResNet18.num_classes + + def _rand(a=0., b=1.): + return np.random.rand() * (b - a) + a + + def _preprocess_true_boxes(true_boxes, anchors, in_shape=None): + """Get true boxes.""" + num_layers = anchors.shape[0] // 3 + anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + true_boxes = np.array(true_boxes, dtype='float32') + # input_shape = np.array([in_shape, in_shape], dtype='int32') + input_shape = np.array(in_shape, dtype='int32') + boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2. + boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2] + true_boxes[..., 0:2] = boxes_xy / input_shape[::-1] + true_boxes[..., 2:4] = boxes_wh / input_shape[::-1] + + grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8] + y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), + 5 + num_classes), dtype='float32') for l in range(num_layers)] + + anchors = np.expand_dims(anchors, 0) + anchors_max = anchors / 2. + anchors_min = -anchors_max + + valid_mask = boxes_wh[..., 0] >= 1 + + wh = boxes_wh[valid_mask] + + + if len(wh) >= 1: + wh = np.expand_dims(wh, -2) + boxes_max = wh / 2. + boxes_min = -boxes_max + + intersect_min = np.maximum(boxes_min, anchors_min) + intersect_max = np.minimum(boxes_max, anchors_max) + intersect_wh = np.maximum(intersect_max - intersect_min, 0.) + intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1] + box_area = wh[..., 0] * wh[..., 1] + anchor_area = anchors[..., 0] * anchors[..., 1] + iou = intersect_area / (box_area + anchor_area - intersect_area) + + best_anchor = np.argmax(iou, axis=-1) + for t, n in enumerate(best_anchor): + for l in range(num_layers): + if n in anchor_mask[l]: + i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32') + j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32') + k = anchor_mask[l].index(n) + + c = true_boxes[t, 4].astype('int32') + y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4] + y_true[l][j, i, k, 4] = 1. + y_true[l][j, i, k, 5 + c] = 1. 
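+        # All valid ground truth boxes are now encoded in y_true. The block
+        # below pads the per-scale box lists to a fixed length of 50 (cf. the
+        # module-level _NUM_BOXES) so the returned tensors keep static shapes.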
+ + pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32) + pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32) + pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32) + + mask0 = np.reshape(y_true[0][..., 4:5], [-1]) + gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4]) + gt_box0 = gt_box0[mask0 == 1] + pad_gt_box0[:gt_box0.shape[0]] = gt_box0 + + mask1 = np.reshape(y_true[1][..., 4:5], [-1]) + gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4]) + gt_box1 = gt_box1[mask1 == 1] + pad_gt_box1[:gt_box1.shape[0]] = gt_box1 + + mask2 = np.reshape(y_true[2][..., 4:5], [-1]) + gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4]) + gt_box2 = gt_box2[mask2 == 1] + pad_gt_box2[:gt_box2.shape[0]] = gt_box2 + + return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2 + + def _infer_data(img_data, input_shape, box): + w, h = img_data.size + input_h, input_w = input_shape + scale = min(float(input_w) / float(w), float(input_h) / float(h)) + nw = int(w * scale) + nh = int(h * scale) + img_data = img_data.resize((nw, nh), Image.BICUBIC) + + new_image = np.zeros((input_h, input_w, 3), np.float32) + new_image.fill(128) + img_data = np.array(img_data) + if len(img_data.shape) == 2: + img_data = np.expand_dims(img_data, axis=-1) + img_data = np.concatenate([img_data, img_data, img_data], axis=-1) + + dh = int((input_h - nh) / 2) + dw = int((input_w - nw) / 2) + new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data + new_image /= 255. + new_image = np.transpose(new_image, (2, 0, 1)) + new_image = np.expand_dims(new_image, 0) + return new_image, np.array([h, w], np.float32), box + + def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)): + """Data augmentation function.""" + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + iw, ih = image.size + ori_image_shape = np.array([ih, iw], np.int32) + h, w = image_size + + if not is_training: + return _infer_data(image, image_size, box) + + flip = _rand() < .5 + # correct boxes + box_data = np.zeros((max_boxes, 5)) + while True: + # Prevent the situation that all boxes are eliminated + new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \ + _rand(1 - jitter, 1 + jitter) + scale = _rand(0.25, 2) + + if new_ar < 1: + nh = int(scale * h) + nw = int(nh * new_ar) + else: + nw = int(scale * w) + nh = int(nw / new_ar) + + dx = int(_rand(0, w - nw)) + dy = int(_rand(0, h - nh)) + + if len(box) >= 1: + t_box = box.copy() + np.random.shuffle(t_box) + t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx + t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy + if flip: + t_box[:, [0, 2]] = w - t_box[:, [2, 0]] + t_box[:, 0:2][t_box[:, 0:2] < 0] = 0 + t_box[:, 2][t_box[:, 2] > w] = w + t_box[:, 3][t_box[:, 3] > h] = h + box_w = t_box[:, 2] - t_box[:, 0] + box_h = t_box[:, 3] - t_box[:, 1] + t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box + + if len(t_box) >= 1: + box = t_box + break + + box_data[:len(box)] = box + # resize image + image = image.resize((nw, nh), Image.BICUBIC) + # place image + new_image = Image.new('RGB', (w, h), (128, 128, 128)) + new_image.paste(image, (dx, dy)) + image = new_image + + # flip image or not + if flip: + image = image.transpose(Image.FLIP_LEFT_RIGHT) + + # convert image to gray or not + gray = _rand() < .25 + if gray: + image = image.convert('L').convert('RGB') + + # when the channels of image is 1 + image = np.array(image) + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + 
image = np.concatenate([image, image, image], axis=-1)
+
+        # distort image
+        hue = _rand(-hue, hue)
+        sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
+        val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
+        image_data = image / 255.
+        if do_hsv:
+            x = rgb_to_hsv(image_data)
+            x[..., 0] += hue
+            x[..., 0][x[..., 0] > 1] -= 1
+            x[..., 0][x[..., 0] < 0] += 1
+            x[..., 1] *= sat
+            x[..., 2] *= val
+            x[x > 1] = 1
+            x[x < 0] = 0
+            image_data = hsv_to_rgb(x)  # numpy array, 0 to 1
+        image_data = image_data.astype(np.float32)
+
+        # preprocess bounding boxes
+        bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
+            _preprocess_true_boxes(box_data, anchors, image_size)
+
+        return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
+            ori_image_shape, gt_box1, gt_box2, gt_box3
+
+    if is_training:
+        images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
+        return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
+
+    images, shape, anno = _data_aug(image, box, is_training)
+    return images, shape, anno
+
+
+def anno_parser(annos_str):
+    """Parse annotation from string to list."""
+    annos = []
+    for anno_str in annos_str:
+        anno = list(map(int, anno_str.strip().split(',')))
+        annos.append(anno)
+    return annos
+
+
+def filter_valid_data(image_dir, anno_path):
+    """Filter valid image files, which are in both image_dir and anno_path."""
+    image_files = []
+    image_anno_dict = {}
+    if not os.path.isdir(image_dir):
+        raise RuntimeError("Path given is not valid.")
+    if not os.path.isfile(anno_path):
+        raise RuntimeError("Annotation file is not valid.")
+
+    with open(anno_path, "rb") as f:
+        lines = f.readlines()
+    for line in lines:
+        line_str = line.decode("utf-8").strip()
+        line_split = str(line_str).split(' ')
+        file_name = line_split[0]
+        if os.path.isfile(os.path.join(image_dir, file_name)):
+            image_anno_dict[file_name] = anno_parser(line_split[1:])
+            image_files.append(file_name)
+    return image_files, image_anno_dict
+
+
+def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
+    """Create MindRecord file by image_dir and anno_path."""
+    mindrecord_path = os.path.join(mindrecord_dir, prefix)
+    writer = FileWriter(mindrecord_path, file_num)
+    image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
+
+    yolo_json = {
+        "image": {"type": "bytes"},
+        "annotation": {"type": "int64", "shape": [-1, 5]},
+    }
+    writer.add_schema(yolo_json, "yolo_json")
+
+    for image_name in image_files:
+        image_path = os.path.join(image_dir, image_name)
+        with open(image_path, 'rb') as f:
+            img = f.read()
+        annos = np.array(image_anno_dict[image_name])
+        row = {"image": img, "annotation": annos}
+        writer.write_raw_data([row])
+    writer.commit()
+
+
+def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
+                        is_training=True, num_parallel_workers=8):
+    """Create YOLOv3 dataset with MindDataset."""
+    ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
+                        num_parallel_workers=num_parallel_workers, shuffle=False)
+    decode = C.Decode()
+    ds = ds.map(input_columns=["image"], operations=decode)
+    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
+
+    if is_training:
+        hwc_to_chw = C.HWC2CHW()
+        ds = ds.map(input_columns=["image", "annotation"],
+                    output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
+                    columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
+                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
+        ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
+        ds = ds.batch(batch_size, drop_remainder=True)
+        ds = ds.repeat(repeat_num)
+    else:
+        ds = ds.map(input_columns=["image", "annotation"],
+                    output_columns=["image", "image_shape", "annotation"],
+                    columns_order=["image", "image_shape", "annotation"],
+                    operations=compose_map_func, num_parallel_workers=num_parallel_workers)
+    return ds
diff --git a/tests/st/model_zoo_tests/yolov3/src/yolov3.py b/tests/st/model_zoo_tests/yolov3/src/yolov3.py
new file mode 100644
index 0000000000..0ac6b21070
--- /dev/null
+++ b/tests/st/model_zoo_tests/yolov3/src/yolov3.py
@@ -0,0 +1,748 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""YOLOv3 based on ResNet18."""
+
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context, Tensor
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.communication.management import get_group_size
+from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.ops import composite as C
+
+
+def weight_variable():
+    """Weight variable."""
+    return TruncatedNormal(0.02)
+
+
+class _conv2d(nn.Cell):
+    """Create Conv2D with padding."""
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+        super(_conv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels,
+                              kernel_size=kernel_size, stride=stride, padding=0, pad_mode='same',
+                              weight_init=weight_variable())
+    def construct(self, x):
+        x = self.conv(x)
+        return x
+
+
+def _fused_bn(channels, momentum=0.99):
+    """Get a fused batchnorm."""
+    return nn.BatchNorm2d(channels, momentum=momentum)
+
+
+def _conv_bn_relu(in_channel,
+                  out_channel,
+                  ksize,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  alpha=0.1,
+                  momentum=0.99,
+                  pad_mode="same"):
+    """Get a conv2d batchnorm and relu layer."""
+    return nn.SequentialCell(
+        [nn.Conv2d(in_channel,
+                   out_channel,
+                   kernel_size=ksize,
+                   stride=stride,
+                   padding=padding,
+                   dilation=dilation,
+                   pad_mode=pad_mode),
+         nn.BatchNorm2d(out_channel, momentum=momentum),
+         nn.LeakyReLU(alpha)]
+    )
+
+
+class BasicBlock(nn.Cell):
+    """
+    ResNet basic block.
+
+    Args:
+        in_channels (int): Input channel.
+        out_channels (int): Output channel.
+        stride (int): Stride size for the initial convolutional layer. Default: 1.
+        momentum (float): Momentum for batchnorm layer. Default: 0.99.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        BasicBlock(3, 256, stride=2).
+    """
+    expansion = 1
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1,
+                 momentum=0.99):
+        super(BasicBlock, self).__init__()
+
+        self.conv1 = _conv2d(in_channels, out_channels, 3, stride=stride)
+        self.bn1 = _fused_bn(out_channels, momentum=momentum)
+        self.conv2 = _conv2d(out_channels, out_channels, 3)
+        self.bn2 = _fused_bn(out_channels, momentum=momentum)
+        self.relu = P.ReLU()
+        self.down_sample_layer = None
+        self.downsample = (in_channels != out_channels)
+        if self.downsample:
+            self.down_sample_layer = _conv2d(in_channels, out_channels, 1, stride=stride)
+        self.add = P.TensorAdd()
+
+    def construct(self, x):
+        identity = x
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+
+        if self.downsample:
+            identity = self.down_sample_layer(identity)
+
+        out = self.add(x, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Cell):
+    """
+    ResNet network.
+
+    Args:
+        block (Cell): Block for network.
+        layer_nums (list): Numbers of different layers.
+        in_channels (int): Input channel.
+        out_channels (int): Output channel.
+        num_classes (int): Class number. Default: 80.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        ResNet(ResidualBlock,
+               [3, 4, 6, 3],
+               [64, 256, 512, 1024],
+               [256, 512, 1024, 2048],
+               100).
+    """
+
+    def __init__(self,
+                 block,
+                 layer_nums,
+                 in_channels,
+                 out_channels,
+                 strides=None,
+                 num_classes=80):
+        super(ResNet, self).__init__()
+
+        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
+            raise ValueError("The length of the "
+                             "layer_nums, in_channels and out_channels lists must be 4!")
+
+        self.conv1 = _conv2d(3, 64, 7, stride=2)
+        self.bn1 = _fused_bn(64)
+        self.relu = P.ReLU()
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
+
+        self.layer1 = self._make_layer(block,
+                                       layer_nums[0],
+                                       in_channel=in_channels[0],
+                                       out_channel=out_channels[0],
+                                       stride=strides[0])
+        self.layer2 = self._make_layer(block,
+                                       layer_nums[1],
+                                       in_channel=in_channels[1],
+                                       out_channel=out_channels[1],
+                                       stride=strides[1])
+        self.layer3 = self._make_layer(block,
+                                       layer_nums[2],
+                                       in_channel=in_channels[2],
+                                       out_channel=out_channels[2],
+                                       stride=strides[2])
+        self.layer4 = self._make_layer(block,
+                                       layer_nums[3],
+                                       in_channel=in_channels[3],
+                                       out_channel=out_channels[3],
+                                       stride=strides[3])
+
+        self.num_classes = num_classes
+        if num_classes:
+            self.reduce_mean = P.ReduceMean(keep_dims=True)
+            self.end_point = nn.Dense(out_channels[3], num_classes, has_bias=True,
+                                      weight_init=weight_variable(),
+                                      bias_init=weight_variable())
+            self.squeeze = P.Squeeze(axis=(2, 3))
+
+    def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
+        """
+        Make Layer for ResNet.
+
+        Args:
+            block (Cell): Resnet block.
+            layer_num (int): Layer number.
+            in_channel (int): Input channel.
+            out_channel (int): Output channel.
+            stride (int): Stride size for the initial convolutional layer.
+
+        Returns:
+            SequentialCell, the output layer.
+
+        Examples:
+            _make_layer(BasicBlock, 3, 128, 256, 2).
+ """ + layers = [] + + resblk = block(in_channel, out_channel, stride=stride) + layers.append(resblk) + + for _ in range(1, layer_num - 1): + resblk = block(out_channel, out_channel, stride=1) + layers.append(resblk) + + resblk = block(out_channel, out_channel, stride=1) + layers.append(resblk) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = c5 + if self.num_classes: + out = self.reduce_mean(c5, (2, 3)) + out = self.squeeze(out) + out = self.end_point(out) + + return c3, c4, out + + +def resnet18(class_num=10): + """ + Get ResNet18 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet18 neural network. + + Examples: + resnet18(100). + """ + return ResNet(BasicBlock, + [2, 2, 2, 2], + [64, 64, 128, 256], + [64, 128, 256, 512], + [1, 2, 2, 2], + num_classes=class_num) + + +class YoloBlock(nn.Cell): + """ + YoloBlock for YOLOv3. + + Args: + in_channels (int): Input channel. + out_chls (int): Middle channel. + out_channels (int): Output channel. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). + + Examples: + YoloBlock(1024, 512, 255). + + """ + def __init__(self, in_channels, out_chls, out_channels): + super(YoloBlock, self).__init__() + out_chls_2 = out_chls * 2 + + self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1) + self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) + self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1) + self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3) + + self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True) + + def construct(self, x): + c1 = self.conv0(x) + c2 = self.conv1(c1) + + c3 = self.conv2(c2) + c4 = self.conv3(c3) + + c5 = self.conv4(c4) + c6 = self.conv5(c5) + + out = self.conv6(c6) + return c5, out + + +class YOLOv3(nn.Cell): + """ + YOLOv3 Network. + + Note: + backbone = resnet18. + + Args: + feature_shape (list): Input image shape, [N,C,H,W]. + backbone_shape (list): resnet18 output channels shape. + backbone (Cell): Backbone Network. + out_channel (int): Output channel. + + Returns: + Tensor, output tensor. + + Examples: + YOLOv3(feature_shape=[1,3,416,416], + backbone_shape=[64, 128, 256, 512, 1024] + backbone=darknet53(), + out_channel=255). 
+ """ + def __init__(self, feature_shape, backbone_shape, backbone, out_channel): + super(YOLOv3, self).__init__() + self.out_channel = out_channel + self.net = backbone + self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel) + + self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1) + self.upsample1 = P.ResizeNearestNeighbor((feature_shape[2]//16, feature_shape[3]//16)) + self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3], + out_chls=backbone_shape[-3], + out_channels=out_channel) + + self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1) + self.upsample2 = P.ResizeNearestNeighbor((feature_shape[2]//8, feature_shape[3]//8)) + self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4], + out_chls=backbone_shape[-4], + out_channels=out_channel) + self.concat = P.Concat(axis=1) + + def construct(self, x): + # input_shape of x is (batch_size, 3, h, w) + # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8) + # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16) + # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32) + feature_map1, feature_map2, feature_map3 = self.net(x) + con1, big_object_output = self.backblock0(feature_map3) + + con1 = self.conv1(con1) + ups1 = self.upsample1(con1) + con1 = self.concat((ups1, feature_map2)) + con2, medium_object_output = self.backblock1(con1) + + con2 = self.conv2(con2) + ups2 = self.upsample2(con2) + con3 = self.concat((ups2, feature_map1)) + _, small_object_output = self.backblock2(con3) + + return big_object_output, medium_object_output, small_object_output + + +class DetectionBlock(nn.Cell): + """ + YOLOv3 detection Network. It will finally output the detection result. + + Args: + scale (str): Character, scale. + config (Class): YOLOv3 config. + + Returns: + Tuple, tuple of output tensor,(f1,f2,f3). + + Examples: + DetectionBlock(scale='l',stride=32). 
+ """ + + def __init__(self, scale, config): + super(DetectionBlock, self).__init__() + + self.config = config + if scale == 's': + idx = (0, 1, 2) + elif scale == 'm': + idx = (3, 4, 5) + elif scale == 'l': + idx = (6, 7, 8) + else: + raise KeyError("Invalid scale value for DetectionBlock") + self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32) + self.num_anchors_per_scale = 3 + self.num_attrib = 4 + 1 + self.config.num_classes + self.ignore_threshold = 0.5 + self.lambda_coord = 1 + + self.sigmoid = nn.Sigmoid() + self.reshape = P.Reshape() + self.tile = P.Tile() + self.concat = P.Concat(axis=-1) + self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32) + + def construct(self, x): + num_batch = P.Shape()(x)[0] + grid_size = P.Shape()(x)[2:4] + + # Reshape and transpose the feature to [n, 3, grid_size[0], grid_size[1], num_attrib] + prediction = P.Reshape()(x, (num_batch, + self.num_anchors_per_scale, + self.num_attrib, + grid_size[0], + grid_size[1])) + prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2)) + + range_x = range(grid_size[1]) + range_y = range(grid_size[0]) + grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32) + grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32) + # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid + grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1)) + grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1)) + # Shape is [grid_size[0], grid_size[1], 1, 2] + grid = self.concat((grid_x, grid_y)) + + box_xy = prediction[:, :, :, :, :2] + box_wh = prediction[:, :, :, :, 2:4] + box_confidence = prediction[:, :, :, :, 4:5] + box_probs = prediction[:, :, :, :, 5:] + + box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32) + box_wh = P.Exp()(box_wh) * self.anchors / self.input_shape + box_confidence = self.sigmoid(box_confidence) + box_probs = self.sigmoid(box_probs) + + if self.training: + return grid, prediction, box_xy, box_wh + return box_xy, box_wh, box_confidence, box_probs + + +class Iou(nn.Cell): + """Calculate the iou of boxes.""" + def __init__(self): + super(Iou, self).__init__() + self.min = P.Minimum() + self.max = P.Maximum() + + def construct(self, box1, box2): + box1_xy = box1[:, :, :, :, :, :2] + box1_wh = box1[:, :, :, :, :, 2:4] + box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) + box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) + + box2_xy = box2[:, :, :, :, :, :2] + box2_wh = box2[:, :, :, :, :, 2:4] + box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0) + box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0) + + intersect_mins = self.max(box1_mins, box2_mins) + intersect_maxs = self.min(box1_maxs, box2_maxs) + intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0)) + + intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \ + P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2]) + box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2]) + box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2]) + + iou = intersect_area / (box1_area + box2_area - intersect_area) + return iou + + +class YoloLossBlock(nn.Cell): + """ + YOLOv3 Loss block cell. It will finally output loss of the scale. + + Args: + scale (str): Three scale here, 's', 'm' and 'l'. + config (Class): The default config of YOLOv3. 
+
+    Returns:
+        Tensor, loss of the scale.
+
+    Examples:
+        YoloLossBlock('l', ConfigYOLOV3ResNet18()).
+    """
+
+    def __init__(self, scale, config):
+        super(YoloLossBlock, self).__init__()
+        self.config = config
+        if scale == 's':
+            idx = (0, 1, 2)
+        elif scale == 'm':
+            idx = (3, 4, 5)
+        elif scale == 'l':
+            idx = (6, 7, 8)
+        else:
+            raise KeyError("Invalid scale value for YoloLossBlock")
+        self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
+        self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
+        self.concat = P.Concat(axis=-1)
+        self.iou = Iou()
+        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
+        self.reduce_sum = P.ReduceSum()
+        self.reduce_max = P.ReduceMax(keep_dims=False)
+        self.input_shape = Tensor(tuple(config.img_shape[::-1]), ms.float32)
+
+    def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box):
+
+        object_mask = y_true[:, :, :, :, 4:5]
+        class_probs = y_true[:, :, :, :, 5:]
+
+        grid_shape = P.Shape()(prediction)[1:3]
+        grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
+
+        pred_boxes = self.concat((pred_xy, pred_wh))
+        true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
+        true_wh = y_true[:, :, :, :, 2:4]
+        true_wh = P.Select()(P.Equal()(true_wh, 0.0),
+                             P.Fill()(P.DType()(true_wh), P.Shape()(true_wh), 1.0),
+                             true_wh)
+        true_wh = P.Log()(true_wh / self.anchors * self.input_shape)
+        box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
+
+        gt_shape = P.Shape()(gt_box)
+        gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
+
+        iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)  # [batch, grid[0], grid[1], num_anchor, num_gt]
+        best_iou = self.reduce_max(iou, -1)  # [batch, grid[0], grid[1], num_anchor]
+        ignore_mask = best_iou < self.ignore_threshold
+        ignore_mask = P.Cast()(ignore_mask, ms.float32)
+        ignore_mask = P.ExpandDims()(ignore_mask, -1)
+        ignore_mask = F.stop_gradient(ignore_mask)
+
+        xy_loss = object_mask * box_loss_scale * self.cross_entropy(prediction[:, :, :, :, :2], true_xy)
+        wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - prediction[:, :, :, :, 2:4])
+        confidence_loss = self.cross_entropy(prediction[:, :, :, :, 4:5], object_mask)
+        confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
+        class_loss = object_mask * self.cross_entropy(prediction[:, :, :, :, 5:], class_probs)
+
+        # Get smooth loss
+        xy_loss = self.reduce_sum(xy_loss, ())
+        wh_loss = self.reduce_sum(wh_loss, ())
+        confidence_loss = self.reduce_sum(confidence_loss, ())
+        class_loss = self.reduce_sum(class_loss, ())
+
+        loss = xy_loss + wh_loss + confidence_loss + class_loss
+        return loss / P.Shape()(prediction)[0]
+
+
+class yolov3_resnet18(nn.Cell):
+    """
+    ResNet based YOLOv3 network.
+
+    Args:
+        config (Class): YOLOv3 config.
+
+    Returns:
+        Cell, cell instance of ResNet based YOLOv3 neural network.
+
+    Examples:
+        yolov3_resnet18(ConfigYOLOV3ResNet18()).
+    """
+
+    def __init__(self, config):
+        super(yolov3_resnet18, self).__init__()
+        self.config = config
+
+        # YOLOv3 network
+        self.feature_map = YOLOv3(feature_shape=self.config.feature_shape,
+                                  backbone=ResNet(BasicBlock,
+                                                  self.config.backbone_layers,
+                                                  self.config.backbone_input_shape,
+                                                  self.config.backbone_shape,
+                                                  self.config.backbone_stride,
+                                                  num_classes=None),
+                                  backbone_shape=self.config.backbone_shape,
+                                  out_channel=self.config.out_channel)
+
+        # prediction on the default anchor boxes
+        self.detect_1 = DetectionBlock('l', self.config)
+        self.detect_2 = DetectionBlock('m', self.config)
+        self.detect_3 = DetectionBlock('s', self.config)
+
+    def construct(self, x):
+        big_object_output, medium_object_output, small_object_output = self.feature_map(x)
+        output_big = self.detect_1(big_object_output)
+        output_me = self.detect_2(medium_object_output)
+        output_small = self.detect_3(small_object_output)
+
+        return output_big, output_me, output_small
+
+
+class YoloWithLossCell(nn.Cell):
+    """
+    Provide YOLOv3 training loss through network.
+
+    Args:
+        network (Cell): The training network.
+        config (Class): YOLOv3 config.
+
+    Returns:
+        Tensor, the loss of the network.
+    """
+    def __init__(self, network, config):
+        super(YoloWithLossCell, self).__init__()
+        self.yolo_network = network
+        self.config = config
+        self.loss_big = YoloLossBlock('l', self.config)
+        self.loss_me = YoloLossBlock('m', self.config)
+        self.loss_small = YoloLossBlock('s', self.config)
+
+    def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
+        yolo_out = self.yolo_network(x)
+        loss_l = self.loss_big(yolo_out[0][0], yolo_out[0][1], yolo_out[0][2], yolo_out[0][3], y_true_0, gt_0)
+        loss_m = self.loss_me(yolo_out[1][0], yolo_out[1][1], yolo_out[1][2], yolo_out[1][3], y_true_1, gt_1)
+        loss_s = self.loss_small(yolo_out[2][0], yolo_out[2][1], yolo_out[2][2], yolo_out[2][3], y_true_2, gt_2)
+        return loss_l + loss_m + loss_s
+
+
+class TrainingWrapper(nn.Cell):
+    """
+    Encapsulation class of YOLOv3 network training.
+
+    Append an optimizer to the training network. After that, the construct
+    function can be called to create the backward graph.
+
+    Args:
+        network (Cell): The training network. Note that loss function should have been added.
+        optimizer (Optimizer): Optimizer for updating the weights.
+        sens (Number): The adjust parameter. Default: 1.0.
+ """ + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.weights = ms.ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("mirror_mean") + if auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) + + +class YoloBoxScores(nn.Cell): + """ + Calculate the boxes of the original picture size and the score of each box. + + Args: + config (Class): YOLOv3 config. + + Returns: + Tensor, the boxes of the original picture size. + Tensor, the score of each box. + """ + def __init__(self, config): + super(YoloBoxScores, self).__init__() + self.input_shape = Tensor(np.array(config.img_shape), ms.float32) + self.num_classes = config.num_classes + + def construct(self, box_xy, box_wh, box_confidence, box_probs, image_shape): + batch_size = F.shape(box_xy)[0] + x = box_xy[:, :, :, :, 0:1] + y = box_xy[:, :, :, :, 1:2] + box_yx = P.Concat(-1)((y, x)) + w = box_wh[:, :, :, :, 0:1] + h = box_wh[:, :, :, :, 1:2] + box_hw = P.Concat(-1)((h, w)) + + new_shape = P.Round()(image_shape * P.ReduceMin()(self.input_shape / image_shape)) + offset = (self.input_shape - new_shape) / 2.0 / self.input_shape + scale = self.input_shape / new_shape + box_yx = (box_yx - offset) * scale + box_hw = box_hw * scale + + box_min = box_yx - box_hw / 2.0 + box_max = box_yx + box_hw / 2.0 + boxes = P.Concat(-1)((box_min[:, :, :, :, 0:1], + box_min[:, :, :, :, 1:2], + box_max[:, :, :, :, 0:1], + box_max[:, :, :, :, 1:2])) + image_scale = P.Tile()(image_shape, (1, 2)) + boxes = boxes * image_scale + boxes = F.reshape(boxes, (batch_size, -1, 4)) + boxes_scores = box_confidence * box_probs + boxes_scores = F.reshape(boxes_scores, (batch_size, -1, self.num_classes)) + return boxes, boxes_scores + + +class YoloWithEval(nn.Cell): + """ + Encapsulation class of YOLOv3 evaluation. + + Args: + network (Cell): The training network. Note that loss function and optimizer must not be added. + config (Class): YOLOv3 config. + + Returns: + Tensor, the boxes of the original picture size. + Tensor, the score of each box. + Tensor, the original picture size. 
+ """ + def __init__(self, network, config): + super(YoloWithEval, self).__init__() + self.yolo_network = network + self.box_score_0 = YoloBoxScores(config) + self.box_score_1 = YoloBoxScores(config) + self.box_score_2 = YoloBoxScores(config) + + def construct(self, x, image_shape): + yolo_output = self.yolo_network(x) + boxes_0, boxes_scores_0 = self.box_score_0(*yolo_output[0], image_shape) + boxes_1, boxes_scores_1 = self.box_score_1(*yolo_output[1], image_shape) + boxes_2, boxes_scores_2 = self.box_score_2(*yolo_output[2], image_shape) + boxes = P.Concat(1)((boxes_0, boxes_1, boxes_2)) + boxes_scores = P.Concat(1)((boxes_scores_0, boxes_scores_1, boxes_scores_2)) + return boxes, boxes_scores, image_shape diff --git a/tests/st/model_zoo_tests/yolov3/test_yolov3.py b/tests/st/model_zoo_tests/yolov3/test_yolov3.py new file mode 100644 index 0000000000..6b4057db18 --- /dev/null +++ b/tests/st/model_zoo_tests/yolov3/test_yolov3.py @@ -0,0 +1,157 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""" +######################## train YOLOv3 example ######################## +train YOLOv3 and get network model files(.ckpt) : +python train.py --image_dir /data --anno_path /data/coco/train_coco.txt --mindrecord_dir=/data/Mindrecord_train + +If the mindrecord_dir is empty, it wil generate mindrecord file by image_dir and anno_path. +Note if mindrecord_dir isn't empty, it will use mindrecord_dir rather than image_dir and anno_path. 
+""" + +import os +import time +import pytest +import numpy as np +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.train import Model +from mindspore.common.initializer import initializer +from mindspore.train.callback import Callback + +from src.yolov3 import yolov3_resnet18, YoloWithLossCell, TrainingWrapper +from src.dataset import create_yolo_dataset +from src.config import ConfigYOLOV3ResNet18 + +np.random.seed(1) +def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate, steps=False): + """Set learning rate.""" + lr_each_step = [] + for i in range(global_step): + if steps: + lr_each_step.append(learning_rate * (decay_rate ** (i // decay_step))) + else: + lr_each_step.append(learning_rate * (decay_rate ** (i / decay_step))) + lr_each_step = np.array(lr_each_step).astype(np.float32) + lr_each_step = lr_each_step[start_step:] + return lr_each_step + + +def init_net_param(network, init_value='ones'): + """Init:wq the parameters in network.""" + params = network.trainable_params() + for p in params: + if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name: + p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype())) + +class ModelCallback(Callback): + def __init__(self): + super(ModelCallback, self).__init__() + self.loss_list = [] + + def step_end(self, run_context): + cb_params = run_context.original_args() + self.loss_list.append(cb_params.net_outputs.asnumpy()) + print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) + +class TimeMonitor(Callback): + """Time Monitor.""" + def __init__(self, data_size): + super(TimeMonitor, self).__init__() + self.data_size = data_size + self.epoch_mseconds_list = [] + self.per_step_mseconds_list = [] + def epoch_begin(self, run_context): + self.epoch_time = time.time() + + def epoch_end(self, run_context): + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + self.epoch_mseconds_list.append(epoch_mseconds) + self.per_step_mseconds_list.append(epoch_mseconds / self.data_size) + +DATA_DIR = "/home/workspace/mindspore_dataset/coco/coco2017/mindrecord_train/yolov3" + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_yolov3(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + rank = 0 + device_num = 1 + lr_init = 0.001 + epoch_size = 3 + batch_size = 32 + loss_scale = 1024 + mindrecord_dir = DATA_DIR + + # It will generate mindrecord file in args_opt.mindrecord_dir, + # and the file name is yolo.mindrecord0, 1, ... file_num. + if not os.path.isdir(mindrecord_dir): + raise KeyError("mindrecord path is not exist.") + + prefix = "yolo.mindrecord" + mindrecord_file = os.path.join(mindrecord_dir, prefix + "0") + print("yolov3 mindrecord is ", mindrecord_file) + if not os.path.exists(mindrecord_file): + print("mindrecord file is not exist.") + assert False + else: + loss_scale = float(loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as yolo.mindrecord0. 
+        dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size,
+                                      batch_size=batch_size, device_num=device_num, rank=rank)
+        dataset_size = dataset.get_dataset_size()
+        print("Create dataset done!")
+
+        net = yolov3_resnet18(ConfigYOLOV3ResNet18())
+        net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
+        init_net_param(net)
+
+        total_epoch_size = 60
+        lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
+                           global_step=total_epoch_size * dataset_size,
+                           decay_step=1000, decay_rate=0.95, steps=True))
+        opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
+        net = TrainingWrapper(net, opt, loss_scale)
+
+        model_callback = ModelCallback()
+        time_monitor_callback = TimeMonitor(data_size=dataset_size)
+        callback = [model_callback, time_monitor_callback]
+
+        model = Model(net)
+        print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
+        model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=True)
+        # the assertions below fail if the loss value, overflow state or loss_scale value is wrong
+        loss_value = np.array(model_callback.loss_list)
+
+        expect_loss_value = [6600, 4200, 2700]
+        print("loss value: {}".format(loss_value))
+        assert loss_value[0] < expect_loss_value[0]
+        assert loss_value[1] < expect_loss_value[1]
+        assert loss_value[2] < expect_loss_value[2]
+
+        epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
+        expect_epoch_mseconds = 950
+        print("epoch mseconds: {}".format(epoch_mseconds))
+        assert epoch_mseconds <= expect_epoch_mseconds
+
+        per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
+        expect_per_step_mseconds = 110
+        print("per step mseconds: {}".format(per_step_mseconds))
+        assert per_step_mseconds <= expect_per_step_mseconds
+        print("yolov3 test case passed.")
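
---
Note (illustration only, not part of the patch): get_lr in test_yolov3.py
builds a step-wise exponential decay schedule. A minimal NumPy-only sketch of
the same rule, handy for checking expected values outside MindSpore (the
function name step_decay is ours, not from the patch):

    import numpy as np

    def step_decay(lr_init, total_steps, decay_step, decay_rate):
        # lr_init * decay_rate ** (step // decay_step), as in get_lr(steps=True)
        steps = np.arange(total_steps)
        return (lr_init * decay_rate ** (steps // decay_step)).astype(np.float32)

    lr = step_decay(0.001, 3000, 1000, 0.95)  # values used by the test
    assert np.isclose(lr[0], 0.001)           # steps 0-999 keep lr_init
    assert np.isclose(lr[1500], 0.00095)      # one decay applied after step 1000

With decay_step=1000 and decay_rate=0.95, the schedule stays flat for the
first 1000 steps and then drops by 5% at each later 1000-step boundary,
matching the lr Tensor built inside test_yolov3 above.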