upload yolov3-darknet53 quant code

pull/4319/head
chengxianbin 5 years ago
parent 0df5a56159
commit 759392571f

@@ -29,7 +29,7 @@ from mindspore._checkparam import Rel
 import mindspore.context as context
 from .normalization import BatchNorm2d, BatchNorm1d
-from .activation import get_activation, ReLU
+from .activation import get_activation, ReLU, LeakyReLU
 from ..cell import Cell
 from . import conv, basic
 from ..._checkparam import ParamValidator as validator
@@ -115,7 +115,11 @@ class Conv2dBnAct(Cell):
                  weight_init='normal',
                  bias_init='zeros',
                  has_bn=False,
-                 activation=None):
+                 momentum=0.9,
+                 eps=1e-5,
+                 activation=None,
+                 alpha=0.2,
+                 after_fake=True):
         super(Conv2dBnAct, self).__init__()
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -145,9 +149,13 @@ class Conv2dBnAct(Cell):
         self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
+        self.after_fake = after_fake
         if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels)
-        self.activation = get_activation(activation)
+            self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
+        if activation == "leakyrelu":
+            self.activation = LeakyReLU(alpha)
+        else:
+            self.activation = get_activation(activation)

     def construct(self, x):
         x = self.conv(x)

@@ -244,7 +244,7 @@ class ConvertToQuantNetwork:
             subcell.conv = conv_inner
             if subcell.has_act and subcell.activation is not None:
                 subcell.activation = self._convert_activation(subcell.activation)
-            else:
+            elif subcell.after_fake:
                 subcell.has_act = True
                 subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                                num_bits=self.act_bits,
@@ -274,7 +274,7 @@ class ConvertToQuantNetwork:
             subcell.dense = dense_inner
             if subcell.has_act and subcell.activation is not None:
                 subcell.activation = self._convert_activation(subcell.activation)
-            else:
+            elif subcell.after_fake:
                 subcell.has_act = True
                 subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                                num_bits=self.act_bits,
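
For reference, here is a minimal usage sketch of the extended `Conv2dBnAct` cell; the argument values are illustrative, not taken from this commit:

```python
import mindspore.nn as nn

# 3x3 convolution with fused BatchNorm and LeakyReLU, exercising the new
# arguments introduced above (momentum, eps, alpha, after_fake).
conv = nn.Conv2dBnAct(32, 64, 3,
                      stride=1,
                      has_bn=True,
                      momentum=0.9,            # BatchNorm momentum (new)
                      eps=1e-5,                # BatchNorm epsilon (new)
                      activation='leakyrelu',  # routed to LeakyReLU(alpha)
                      alpha=0.1,               # LeakyReLU negative slope (new)
                      after_fake=True)         # lets ConvertToQuantNetwork add a
                                               # fake-quant node after this cell
```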

@@ -0,0 +1,143 @@
# YOLOV3-DarkNet53-Quant Example
## Description
This is an example of training YOLOV3-DarkNet53-Quant on the COCO2014 dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the dataset COCO2014.
> Unzip the COCO2014 dataset to any path you want; the folder should include the train and eval datasets as follows:
```
.
└─dataset
├─train2014
├─val2014
└─annotations
```
## Structure
```shell
.
└─yolov3_darknet53_quant
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training (1p)
├─run_distribute_train.sh # launch distributed training (8p)
└─run_eval.sh # launch evaluation
├─src
├─__init__.py # python init file
├─config.py # parameter configuration
├─darknet.py # backbone of network
├─distributed_sampler.py # dataset iterator
├─initializer.py # initializer of parameters
├─logger.py # log function
├─loss.py # loss function
├─lr_scheduler.py # generate learning rate
├─transforms.py # preprocess data
├─util.py # util function
├─yolo.py # yolov3 network
├─yolo_dataset.py # create dataset for YOLOV3
├─eval.py # eval net
└─train.py # train net
```
## Running the example
### Train
#### Usage
```
# distributed training
sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]
# standalone training
sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]
```
#### Launch
```bash
# distributed training example(8p)
sh run_distribute_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt rank_table_8p.json
# standalone training example(1p)
sh run_standalone_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt
```
> For details about rank_table.json, refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training results are stored in the scripts path, in a folder whose name begins with "train" or "train_parallel". There you can find checkpoint files together with results like the following in log.txt.
```
# distribute training result(8p)
epoch[0], iter[0], loss:483.341675, 0.31 imgs/sec, lr:0.0
epoch[0], iter[100], loss:55.690952, 3.46 imgs/sec, lr:0.0
epoch[0], iter[200], loss:54.045728, 126.54 imgs/sec, lr:0.0
epoch[0], iter[300], loss:48.771608, 133.04 imgs/sec, lr:0.0
epoch[0], iter[400], loss:48.486769, 139.69 imgs/sec, lr:0.0
epoch[0], iter[500], loss:48.649275, 143.29 imgs/sec, lr:0.0
epoch[0], iter[600], loss:44.731309, 144.03 imgs/sec, lr:0.0
epoch[1], iter[700], loss:43.037023, 136.08 imgs/sec, lr:0.0
epoch[1], iter[800], loss:41.514788, 132.94 imgs/sec, lr:0.0
epoch[133], iter[85700], loss:33.326716, 136.14 imgs/sec, lr:6.497331924038008e-06
epoch[133], iter[85800], loss:34.968744, 136.76 imgs/sec, lr:6.497331924038008e-06
epoch[134], iter[85900], loss:35.868543, 137.08 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86000], loss:35.740817, 139.49 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86100], loss:34.600463, 141.47 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86200], loss:36.641916, 137.91 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86300], loss:32.819769, 138.17 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86400], loss:35.603033, 142.23 imgs/sec, lr:1.6245529650404933e-06
epoch[134], iter[86500], loss:34.303755, 145.18 imgs/sec, lr:1.6245529650404933e-06
...
```
### Infer
#### Usage
```
# infer
sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
```bash
# infer with checkpoint
sh run_eval.sh dataset/coco2014/ checkpoint/0-135.ckpt
```
> The checkpoint file is produced during the training process.
#### Result
Inference results are stored in the scripts path, in a folder named "eval". There you can find results like the following in log.txt.
```
=============coco eval result=========
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.310
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.531
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.130
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.260
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.402
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.429
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.232
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558
```
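
The table above is the standard pycocotools summary. As a hedged sketch of how such a table is produced (file names are placeholders; the actual logic lives in eval.py):

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Placeholder paths: ground-truth annotations and detections in COCO JSON format.
coco_gt = COCO('annotations/instances_val2014.json')
coco_dt = coco_gt.loadRes('predict.json')

coco_eval = COCOeval(coco_gt, coco_dt, iouType='bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints the AP/AR table in the format shown above
```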

File diff suppressed because it is too large

@@ -0,0 +1,83 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
RESUME_YOLOV3=$(get_real_path $2)
MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3)
echo $DATASET_PATH
echo $RESUME_YOLOV3
echo $MINDSPORE_HCCL_CONFIG_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
exit 1
fi
if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ]
then
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file"
exit 1
fi
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--resume_yolov3=$RESUME_YOLOV3 \
--is_distributed=1 \
--per_batch_size=16 \
--lr=0.012 \
--T_max=135 \
--max_epoch=135 \
--warmup_epochs=5 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
done

@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
CHECKPOINT_PATH=$(get_real_path $2)
echo $DATASET_PATH
echo $CHECKPOINT_PATH
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $CHECKPOINT_PATH ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py \
--data_dir=$DATASET_PATH \
--pretrained=$CHECKPOINT_PATH \
--testing_shape=416 > log.txt 2>&1 &
cd ..

@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH
RESUME_YOLOV3=$(get_real_path $2)
echo $RESUME_YOLOV3
if [ ! -d $DATASET_PATH ]
then
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1
fi
if [ ! -f $RESUME_YOLOV3 ]
then
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
exit 1
fi
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py \
--data_dir=$DATASET_PATH \
--resume_yolov3=$RESUME_YOLOV3 \
--is_distributed=0 \
--per_batch_size=16 \
--lr=0.004 \
--T_max=135 \
--max_epoch=135 \
--warmup_epochs=5 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..

@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Config parameters for Darknet based yolov3_darknet53 models."""
class ConfigYOLOV3DarkNet53:
"""
Config parameters for the yolov3_darknet53.
Examples:
ConfigYOLOV3DarkNet53()
"""
# train_param
# data augmentation related
hue = 0.1
saturation = 1.5
value = 1.5
jitter = 0.3
resize_rate = 1
multi_scale = [[320, 320],
[352, 352],
[384, 384],
[416, 416],
[448, 448],
[480, 480],
[512, 512],
[544, 544],
[576, 576],
[608, 608]
]
num_classes = 80
max_box = 50
backbone_input_shape = [32, 64, 128, 256, 512]
backbone_shape = [64, 128, 256, 512, 1024]
backbone_layers = [1, 2, 8, 8, 4]
# confidence under ignore_threshold means no object when training
ignore_threshold = 0.7
# h->w
anchor_scales = [(10, 13),
(16, 30),
(33, 23),
(30, 61),
(62, 45),
(59, 119),
(116, 90),
(156, 198),
(373, 326)]
out_channel = 255
quantization_aware = True
# test_param
test_img_shape = [416, 416]
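
Note that out_channel follows from the YOLOv3 head layout: each of the 3 anchors per output scale predicts 4 box offsets, 1 objectness score and num_classes class scores, so 3 * (80 + 5) = 255. A quick sanity check (illustrative, not part of the file):

```python
cfg = ConfigYOLOV3DarkNet53()
anchors_per_scale = len(cfg.anchor_scales) // 3  # 9 anchors spread over 3 scales
assert cfg.out_channel == anchors_per_scale * (cfg.num_classes + 5)  # 3 * 85 = 255
```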

@@ -0,0 +1,208 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DarkNet model."""
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_block(in_channels,
out_channels,
kernel_size,
stride,
dilation=1):
"""Get a conv2d batchnorm and relu layer"""
pad_mode = 'same'
padding = 0
return nn.Conv2dBnAct(in_channels, out_channels, kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
has_bn=True,
momentum=0.1,
activation='relu')
class ResidualBlock(nn.Cell):
"""
DarkNet V1 residual block definition.
Args:
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
Returns:
Tensor, output tensor.
Examples:
ResidualBlock(3, 208)
"""
expansion = 4
def __init__(self,
in_channels,
out_channels):
super(ResidualBlock, self).__init__()
out_chls = out_channels//2
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
self.add = P.TensorAdd()
def construct(self, x):
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.add(out, identity)
return out
class DarkNet(nn.Cell):
"""
DarkNet V1 network.
Args:
block: Cell. Block for network.
layer_nums: List. Numbers of different layers.
in_channels: Integer. Input channel.
out_channels: Integer. Output channel.
detect: Bool. Whether detect or not. Default:False.
Returns:
Tuple, tuple of output tensors, (f1, f2, f3, f4, f5).
Examples:
DarkNet(ResidualBlock,
[1, 2, 8, 8, 4],
[32, 64, 128, 256, 512],
[64, 128, 256, 512, 1024],
100)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
detect=False):
super(DarkNet, self).__init__()
self.outchannel = out_channels[-1]
self.detect = detect
if not len(layer_nums) == len(in_channels) == len(out_channels) == 5:
raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!")
self.conv0 = conv_block(3,
in_channels[0],
kernel_size=3,
stride=1)
self.conv1 = conv_block(in_channels[0],
out_channels[0],
kernel_size=3,
stride=2)
self.conv2 = conv_block(in_channels[1],
out_channels[1],
kernel_size=3,
stride=2)
self.conv3 = conv_block(in_channels[2],
out_channels[2],
kernel_size=3,
stride=2)
self.conv4 = conv_block(in_channels[3],
out_channels[3],
kernel_size=3,
stride=2)
self.conv5 = conv_block(in_channels[4],
out_channels[4],
kernel_size=3,
stride=2)
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=out_channels[0],
out_channel=out_channels[0])
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=out_channels[1],
out_channel=out_channels[1])
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=out_channels[2],
out_channel=out_channels[2])
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=out_channels[3],
out_channel=out_channels[3])
self.layer5 = self._make_layer(block,
layer_nums[4],
in_channel=out_channels[4],
out_channel=out_channels[4])
def _make_layer(self, block, layer_num, in_channel, out_channel):
"""
Make Layer for DarkNet.
:param block: Cell. DarkNet block.
:param layer_num: Integer. Layer number.
:param in_channel: Integer. Input channel.
:param out_channel: Integer. Output channel.
Examples:
_make_layer(ConvBlock, 1, 128, 256)
"""
layers = []
darkblk = block(in_channel, out_channel)
layers.append(darkblk)
for _ in range(1, layer_num):
darkblk = block(out_channel, out_channel)
layers.append(darkblk)
return nn.SequentialCell(layers)
def construct(self, x):
c1 = self.conv0(x)
c2 = self.conv1(c1)
c3 = self.layer1(c2)
c4 = self.conv2(c3)
c5 = self.layer2(c4)
c6 = self.conv3(c5)
c7 = self.layer3(c6)
c8 = self.conv4(c7)
c9 = self.layer4(c8)
c10 = self.conv5(c9)
c11 = self.layer5(c10)
if self.detect:
return c7, c9, c11
return c11
def get_out_channels(self):
return self.outchannel
def darknet53():
"""
Get DarkNet53 neural network.
Returns:
Cell, cell instance of DarkNet53 neural network.
Examples:
darknet53()
"""
return DarkNet(ResidualBlock, [1, 2, 8, 8, 4],
[32, 64, 128, 256, 512],
[64, 128, 256, 512, 1024])
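
A minimal forward-pass sketch, assuming a configured MindSpore runtime; the output shape follows from the five stride-2 convolutions (416 / 2^5 = 13):

```python
import numpy as np
from mindspore import Tensor

net = darknet53()  # detect=False, so only the last feature map is returned
x = Tensor(np.random.rand(1, 3, 416, 416).astype(np.float32))
out = net(x)                   # expected shape: (1, 1024, 13, 13)
print(net.get_out_channels())  # 1024
```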

@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Yolo dataset distributed sampler."""
from __future__ import division
import math
import numpy as np
class DistributedSampler:
"""Distributed sampler."""
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
print("***********Setting world_size to 1 since it is not passed in ******************")
num_replicas = 1
if rank is None:
print("***********Setting rank to 0 since it is not passed in ******************")
rank = 0
self.dataset_size = dataset_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
# np.ndarray of numbers from 0 to dataset_size - 1, used as dataset indices
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
# change to list type
indices = indices.tolist()
self.epoch += 1
else:
indices = list(range(self.dataset_size))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
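
With shuffling disabled, each rank simply iterates over a strided, disjoint slice of the index space. A quick illustration:

```python
# Two replicas over ten samples: each rank sees five indices.
sampler = DistributedSampler(dataset_size=10, num_replicas=2, rank=0, shuffle=False)
print(list(sampler))  # [0, 2, 4, 6, 8]; rank 1 would get [1, 3, 5, 7, 9]
```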

@@ -0,0 +1,179 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parameter init."""
import math
import numpy as np
from mindspore.common import initializer as init
from mindspore.common.initializer import Initializer as MeInitializer
import mindspore.nn as nn
from mindspore import Tensor
np.random.seed(5)
def calculate_gain(nonlinearity, param=None):
r"""Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function (`nn.functional` name)
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of 'num' and 'arr'."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_correct_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
r"""Fills the input `Tensor` with values according to the method
described in `Delving deep into rectifiers: Surpassing human-level
performance on ImageNet classification` - He, K. et al. (2015), using a
uniform distribution. The resulting tensor will have values sampled from
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Also known as He initialization.
Args:
arr: an n-dimensional numpy array
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
Examples:
>>> w = np.empty((3, 5))
>>> kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
"""
fan = _calculate_correct_fan(arr, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, arr.shape)
def _calculate_fan_in_and_fan_out(arr):
"""Calculate fan in and fan out."""
dimensions = len(arr.shape)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
num_input_fmaps = arr.shape[1]
num_output_fmaps = arr.shape[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = arr[0][0].size
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
class KaimingUniform(MeInitializer):
"""Kaiming uniform initializer."""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingUniform, self).__init__()
self.a = a
self.mode = mode
self.nonlinearity = nonlinearity
def _initialize(self, arr):
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
_assignment(arr, tmp)
def default_recurisive_init(custom_cell):
"""Initialize parameter."""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, nn.Dense):
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.default_input.shape,
cell.weight.default_input.dtype).to_tensor()
if cell.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
bound = 1 / math.sqrt(fan_in)
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
cell.bias.default_input.dtype)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
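
A small sketch of the numpy-level initializer; a=math.sqrt(5) matches the value passed by default_recurisive_init above:

```python
import math
import numpy as np

# Weights for a conv with 16 input channels and a 3x3 kernel: fan_in = 16 * 9.
w = kaiming_uniform_(np.zeros((32, 16, 3, 3), dtype=np.float32), a=math.sqrt(5))
bound = math.sqrt(3.0) * calculate_gain('leaky_relu', math.sqrt(5)) / math.sqrt(16 * 9)
assert abs(w).max() <= bound  # samples stay inside the computed uniform bound
```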

@@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('yolov3_darknet53', rank)
logger.setup_logging_file(path, rank)
return logger
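
Typical usage, with an illustrative output directory:

```python
logger = get_logger('./outputs', rank=0)  # console handler plus a rank-0 log file
logger.info('training started')
logger.important_info('stage done')       # banner-style message, rank 0 only
```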

@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 loss."""
from mindspore.ops import operations as P
import mindspore.nn as nn
class XYLoss(nn.Cell):
"""Loss for x and y."""
def __init__(self):
super(XYLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
xy_loss = self.reduce_sum(xy_loss, ())
return xy_loss
class WHLoss(nn.Cell):
"""Loss for w and h."""
def __init__(self):
super(WHLoss, self).__init__()
self.square = P.Square()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
wh_loss = object_mask * box_loss_scale * 0.5 * self.square(true_wh - predict_wh)
wh_loss = self.reduce_sum(wh_loss, ())
return wh_loss
class ConfidenceLoss(nn.Cell):
"""Loss for confidence."""
def __init__(self):
super(ConfidenceLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_confidence, ignore_mask):
confidence_loss = self.cross_entropy(predict_confidence, object_mask)
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
confidence_loss = self.reduce_sum(confidence_loss, ())
return confidence_loss
class ClassLoss(nn.Cell):
"""Loss for classification."""
def __init__(self):
super(ClassLoss, self).__init__()
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
self.reduce_sum = P.ReduceSum()
def construct(self, object_mask, predict_class, class_probs):
class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
class_loss = self.reduce_sum(class_loss, ())
return class_loss
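
The complete YOLOv3 loss in yolo.py (whose diff is suppressed below) combines these four terms per detection head. A hedged sketch of that wiring, illustrative only:

```python
class YoloLossSketch(nn.Cell):
    """Sum the four loss parts for one detection head (illustrative sketch)."""
    def __init__(self):
        super(YoloLossSketch, self).__init__()
        self.xy_loss = XYLoss()
        self.wh_loss = WHLoss()
        self.confidence_loss = ConfidenceLoss()
        self.class_loss = ClassLoss()

    def construct(self, object_mask, box_loss_scale, predict_xy, true_xy,
                  predict_wh, true_wh, predict_confidence, ignore_mask,
                  predict_class, class_probs):
        loss_xy = self.xy_loss(object_mask, box_loss_scale, predict_xy, true_xy)
        loss_wh = self.wh_loss(object_mask, box_loss_scale, predict_wh, true_wh)
        loss_conf = self.confidence_loss(object_mask, predict_confidence, ignore_mask)
        loss_cls = self.class_loss(object_mask, predict_class, class_probs)
        return loss_xy + loss_wh + loss_conf + loss_cls
```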

@@ -0,0 +1,143 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate scheduler."""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""Linear learning rate."""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""Warmup step learning rate."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Cosine annealing learning rate."""
base_lr = lr
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = 0
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Cosine annealing learning rate V2."""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
last_lr = 0
last_epoch_V1 = 0
T_max_V2 = int(max_epoch*1/3)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
if i < total_steps*2/3:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
last_lr = lr
last_epoch_V1 = last_epoch
else:
base_lr = last_lr
last_epoch = last_epoch-last_epoch_V1
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""Warmup cosine annealing learning rate."""
start_sample_epoch = 60
step_sample = 2
tobe_sampled_epoch = 60
end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch
max_sampled_epoch = max_epoch+tobe_sampled_epoch
T_max = max_sampled_epoch
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_sampled_steps):
last_epoch = i // steps_per_epoch
if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
continue
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
lr_each_step.append(lr)
assert total_steps == len(lr_each_step)
return np.array(lr_each_step).astype(np.float32)
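
For example, the cosine schedule requested by run_distribute_train.sh (lr=0.012, 5 warmup epochs, max_epoch=135, T_max=135) can be generated as below; steps_per_epoch is an assumption, since it depends on dataset size and global batch size:

```python
steps_per_epoch = 646  # assumption: ~82k COCO train images / (8 devices * batch 16)
lr_steps = warmup_cosine_annealing_lr(0.012, steps_per_epoch,
                                      warmup_epochs=5, max_epoch=135, T_max=135)
print(lr_steps.shape)             # (87210,) = 135 epochs * 646 steps
print(lr_steps[0], lr_steps[-1])  # 0.0 during warmup; the tail approaches eta_min
```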

@@ -0,0 +1,177 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
from mindspore.train.serialization import load_checkpoint
import mindspore.nn as nn
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
def load_backbone(net, ckpt_path, args):
"""Load darknet53 backbone checkpoint."""
param_dict = load_checkpoint(ckpt_path)
yolo_backbone_prefix = 'feature_map.backbone'
darknet_backbone_prefix = 'network.backbone'
find_param = []
not_found_param = []
for name, cell in net.cells_and_names():
if name.startswith(yolo_backbone_prefix):
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
if isinstance(cell, (nn.Conv2d, nn.Dense)):
darknet_weight = '{}.weight'.format(name)
darknet_bias = '{}.bias'.format(name)
if darknet_weight in param_dict:
cell.weight.default_input = param_dict[darknet_weight].data
find_param.append(darknet_weight)
else:
not_found_param.append(darknet_weight)
if darknet_bias in param_dict:
cell.bias.default_input = param_dict[darknet_bias].data
find_param.append(darknet_bias)
else:
not_found_param.append(darknet_bias)
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
darknet_moving_mean = '{}.moving_mean'.format(name)
darknet_moving_variance = '{}.moving_variance'.format(name)
darknet_gamma = '{}.gamma'.format(name)
darknet_beta = '{}.beta'.format(name)
if darknet_moving_mean in param_dict:
cell.moving_mean.default_input = param_dict[darknet_moving_mean].data
find_param.append(darknet_moving_mean)
else:
not_found_param.append(darknet_moving_mean)
if darknet_moving_variance in param_dict:
cell.moving_variance.default_input = param_dict[darknet_moving_variance].data
find_param.append(darknet_moving_variance)
else:
not_found_param.append(darknet_moving_variance)
if darknet_gamma in param_dict:
cell.gamma.default_input = param_dict[darknet_gamma].data
find_param.append(darknet_gamma)
else:
not_found_param.append(darknet_gamma)
if darknet_beta in param_dict:
cell.beta.default_input = param_dict[darknet_beta].data
find_param.append(darknet_beta)
else:
not_found_param.append(darknet_beta)
args.logger.info('================found_param {}========='.format(len(find_param)))
args.logger.info(find_param)
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
args.logger.info(not_found_param)
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
return net
def default_wd_filter(x):
"""default weight decay filter."""
parameter_name = x.name
if parameter_name.endswith('.bias'):
# biases do not use weight decay
return False
if parameter_name.endswith('.gamma'):
# BN gamma does not use weight decay; be careful, for now x does not include BN
return False
if parameter_name.endswith('.beta'):
# BN beta does not use weight decay; be careful, for now x does not include BN
return False
return True
def get_param_groups(network):
"""Param groups for optimizer."""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# biases do not use weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# BN gamma does not use weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# BN beta does not use weight decay
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
class ShapeRecord:
"""Log image shape."""
def __init__(self):
self.shape_record = {
320: 0,
352: 0,
384: 0,
416: 0,
448: 0,
480: 0,
512: 0,
544: 0,
576: 0,
608: 0,
'total': 0
}
def set(self, shape):
if len(shape) > 1:
shape = shape[0]
shape = int(shape)
self.shape_record[shape] += 1
self.shape_record['total'] += 1
def show(self, logger):
for key in self.shape_record:
rate = self.shape_record[key] / float(self.shape_record['total'])
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
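
get_param_groups is meant to be handed straight to an optimizer, so that biases and BN gamma/beta skip weight decay. A hedged sketch with illustrative hyper-parameters; network stands for the YOLOv3 cell built in train.py (suppressed diff):

```python
from mindspore.nn.optim import Momentum

opt = Momentum(params=get_param_groups(network),
               learning_rate=0.012, momentum=0.9,
               weight_decay=0.0005)  # applied only to the decay group
```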

File diff suppressed because it is too large

@@ -0,0 +1,184 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""YOLOV3 dataset."""
import os
from PIL import Image
from pycocotools.coco import COCO
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as CV
from src.distributed_sampler import DistributedSampler
from src.transforms import reshape_fn, MultiScaleTrans
min_keypoints_per_image = 10
def _has_only_empty_bbox(anno):
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
def _count_visible_keypoints(anno):
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
def has_valid_annotation(anno):
"""Check annotation file."""
# if it's empty, there is no annotation
if not anno:
return False
# if all boxes have close to zero area, there is no annotation
if _has_only_empty_bbox(anno):
return False
# keypoint tasks use slightly different criteria for considering
# whether an annotation is valid
if "keypoints" not in anno[0]:
return True
# for keypoint detection tasks, only consider images valid when they
# contain at least min_keypoints_per_image visible keypoints
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
return True
return False
class COCOYoloDataset:
"""YOLOV3 Dataset for COCO."""
def __init__(self, root, ann_file, remove_images_without_annotations=True,
filter_crowd_anno=True, is_training=True):
self.coco = COCO(ann_file)
self.root = root
self.img_ids = list(sorted(self.coco.imgs.keys()))
self.filter_crowd_anno = filter_crowd_anno
self.is_training = is_training
# filter images without any annotations
if remove_images_without_annotations:
img_ids = []
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = self.coco.loadAnns(ann_ids)
if has_valid_annotation(anno):
img_ids.append(img_id)
self.img_ids = img_ids
self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
self.cat_ids_to_continuous_ids = {
v: i for i, v in enumerate(self.coco.getCatIds())
}
self.continuous_ids_cat_ids = {
v: k for k, v in self.cat_ids_to_continuous_ids.items()
}
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
(img, target) (tuple): target is a dictionary containing "bbox", "segmentation" or "keypoints",
generated from the image's annotation. img is a PIL image.
"""
coco = self.coco
img_id = self.img_ids[index]
img_path = coco.loadImgs(img_id)[0]["file_name"]
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
if not self.is_training:
return img, img_id
ann_ids = coco.getAnnIds(imgIds=img_id)
target = coco.loadAnns(ann_ids)
# filter crowd annotations
if self.filter_crowd_anno:
annos = [anno for anno in target if anno["iscrowd"] == 0]
else:
annos = [anno for anno in target]
target = {}
boxes = [anno["bbox"] for anno in annos]
target["bboxes"] = boxes
classes = [anno["category_id"] for anno in annos]
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
target["labels"] = classes
bboxes = target['bboxes']
labels = target['labels']
out_target = []
for bbox, label in zip(bboxes, labels):
tmp = []
# convert to [x_min y_min x_max y_max]
bbox = self._convertTopDown(bbox)
tmp.extend(bbox)
tmp.append(int(label))
# tmp [x_min y_min x_max y_max, label]
out_target.append(tmp)
return img, out_target
def __len__(self):
return len(self.img_ids)
def _convertTopDown(self, bbox):
x_min = bbox[0]
y_min = bbox[1]
w = bbox[2]
h = bbox[3]
return [x_min, y_min, x_min+w, y_min+h]
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
config=None, is_training=True, shuffle=True):
"""Create dataset for YOLOV3."""
if is_training:
filter_crowd = True
remove_empty_anno = True
else:
filter_crowd = False
remove_empty_anno = False
yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
remove_images_without_annotations=remove_empty_anno, is_training=is_training)
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
hwc_to_chw = CV.HWC2CHW()
config.dataset_size = len(yolo_dataset)
num_parallel_workers1 = int(64 / device_num)
num_parallel_workers2 = int(16 / device_num)
if is_training:
multi_scale_trans = MultiScaleTrans(config, device_num)
if device_num != 8:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
num_parallel_workers=num_parallel_workers1,
sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
num_parallel_workers=num_parallel_workers2, drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
num_parallel_workers=8, drop_remainder=True)
else:
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
sampler=distributed_sampler)
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
ds = ds.map(input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
columns_order=["image", "image_shape", "img_id"],
operations=compose_map_func, num_parallel_workers=8)
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(max_epoch)
return ds, len(yolo_dataset)
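
A usage sketch for the eval path; the paths are placeholders, and the real call is made by eval.py in the suppressed diff:

```python
from src.config import ConfigYOLOV3DarkNet53

config = ConfigYOLOV3DarkNet53()
ds, data_size = create_yolo_dataset('dataset/coco2014/val2014',
                                    'dataset/coco2014/annotations/instances_val2014.json',
                                    batch_size=1, max_epoch=1, device_num=1, rank=0,
                                    config=config, is_training=False)
print(data_size)  # number of images kept after annotation filtering
```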

File diff suppressed because it is too large