parent
0df5a56159
commit
759392571f
@ -0,0 +1,143 @@
|
||||
# YOLOV3-DarkNet53-Quant Example
|
||||
|
||||
## Description
|
||||
|
||||
This is an example of training YOLOV3-DarkNet53-Quant with COCO2014 dataset in MindSpore.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
- Download the dataset COCO2014.
|
||||
|
||||
> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows:
|
||||
|
||||
```
|
||||
.
|
||||
└─dataset
|
||||
├─train2014
|
||||
├─val2014
|
||||
└─annotations
|
||||
```
|
||||
|
||||
## Structure
|
||||
|
||||
```shell
|
||||
.
|
||||
└─yolov3_darknet53_quant
|
||||
├─README.md
|
||||
├─scripts
|
||||
├─run_standalone_train.sh # launch standalone training(1p)
|
||||
├─run_distribute_train.sh # launch distributed training(8p)
|
||||
└─run_eval.sh # launch evaluating
|
||||
├─src
|
||||
├─__init__.py # python init file
|
||||
├─config.py # parameter configuration
|
||||
├─darknet.py # backbone of network
|
||||
├─distributed_sampler.py # iterator of dataset
|
||||
├─initializer.py # initializer of parameters
|
||||
├─logger.py # log function
|
||||
├─loss.py # loss function
|
||||
├─lr_scheduler.py # generate learning rate
|
||||
├─transforms.py # Preprocess data
|
||||
├─util.py # util function
|
||||
├─yolo.py # yolov3 network
|
||||
├─yolo_dataset.py # create dataset for YOLOV3
|
||||
├─eval.py # eval net
|
||||
└─train.py # train net
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
### Train
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# distributed training
|
||||
sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]
|
||||
|
||||
# standalone training
|
||||
sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
# distributed training example(8p)
|
||||
sh run_distribute_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt rank_table_8p.json
|
||||
|
||||
# standalone training example(1p)
|
||||
sh run_standalone_train.sh dataset/coco2014 yolov3_darknet_noquant_ckpt/0-320_102400.ckpt
|
||||
```
|
||||
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||
|
||||
#### Result
|
||||
|
||||
Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt.
|
||||
|
||||
```
|
||||
# distribute training result(8p)
|
||||
epoch[0], iter[0], loss:483.341675, 0.31 imgs/sec, lr:0.0
|
||||
epoch[0], iter[100], loss:55.690952, 3.46 imgs/sec, lr:0.0
|
||||
epoch[0], iter[200], loss:54.045728, 126.54 imgs/sec, lr:0.0
|
||||
epoch[0], iter[300], loss:48.771608, 133.04 imgs/sec, lr:0.0
|
||||
epoch[0], iter[400], loss:48.486769, 139.69 imgs/sec, lr:0.0
|
||||
epoch[0], iter[500], loss:48.649275, 143.29 imgs/sec, lr:0.0
|
||||
epoch[0], iter[600], loss:44.731309, 144.03 imgs/sec, lr:0.0
|
||||
epoch[1], iter[700], loss:43.037023, 136.08 imgs/sec, lr:0.0
|
||||
epoch[1], iter[800], loss:41.514788, 132.94 imgs/sec, lr:0.0
|
||||
|
||||
…
|
||||
epoch[133], iter[85700], loss:33.326716, 136.14 imgs/sec, lr:6.497331924038008e-06
|
||||
epoch[133], iter[85800], loss:34.968744, 136.76 imgs/sec, lr:6.497331924038008e-06
|
||||
epoch[134], iter[85900], loss:35.868543, 137.08 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86000], loss:35.740817, 139.49 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86100], loss:34.600463, 141.47 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86200], loss:36.641916, 137.91 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86300], loss:32.819769, 138.17 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86400], loss:35.603033, 142.23 imgs/sec, lr:1.6245529650404933e-06
|
||||
epoch[134], iter[86500], loss:34.303755, 145.18 imgs/sec, lr:1.6245529650404933e-06
|
||||
...
|
||||
```
|
||||
|
||||
### Infer
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# infer
|
||||
sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```bash
|
||||
# infer with checkpoint
|
||||
sh run_eval.sh dataset/coco2014/ checkpoint/0-135.ckpt
|
||||
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
|
||||
#### Result
|
||||
|
||||
Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt.
|
||||
|
||||
```
|
||||
=============coco eval reulst=========
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.310
|
||||
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.531
|
||||
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.130
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326
|
||||
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.425
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.260
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.402
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.429
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.232
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.450
|
||||
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.558
|
||||
```
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,83 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [RESUME_YOLOV3] [MINDSPORE_HCCL_CONFIG_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
RESUME_YOLOV3=$(get_real_path $2)
|
||||
MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3)
|
||||
|
||||
echo $DATASET_PATH
|
||||
echo $RESUME_YOLOV3
|
||||
echo $MINDSPORE_HCCL_CONFIG_PATH
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $RESUME_YOLOV3 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ]
|
||||
then
|
||||
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py \
|
||||
--data_dir=$DATASET_PATH \
|
||||
--resume_yolov3=$RESUME_YOLOV3 \
|
||||
--is_distributed=1 \
|
||||
--per_batch_size=16 \
|
||||
--lr=0.012 \
|
||||
--T_max=135 \
|
||||
--max_epoch=135 \
|
||||
--warmup_epochs=5 \
|
||||
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
@ -0,0 +1,67 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
CHECKPOINT_PATH=$(get_real_path $2)
|
||||
echo $DATASET_PATH
|
||||
echo $CHECKPOINT_PATH
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $CHECKPOINT_PATH ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start infering for device $DEVICE_ID"
|
||||
python eval.py \
|
||||
--data_dir=$DATASET_PATH \
|
||||
--pretrained=$CHECKPOINT_PATH \
|
||||
--testing_shape=416 > log.txt 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [RESUME_YOLOV3]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
DATASET_PATH=$(get_real_path $1)
|
||||
echo $DATASET_PATH
|
||||
RESUME_YOLOV3=$(get_real_path $2)
|
||||
echo $RESUME_YOLOV3
|
||||
|
||||
if [ ! -d $DATASET_PATH ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $RESUME_YOLOV3 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$RESUME_YOLOV3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
python train.py \
|
||||
--data_dir=$DATASET_PATH \
|
||||
--resume_yolov3=$RESUME_YOLOV3 \
|
||||
--is_distributed=0 \
|
||||
--per_batch_size=16 \
|
||||
--lr=0.004 \
|
||||
--T_max=135 \
|
||||
--max_epoch=135 \
|
||||
--warmup_epochs=5 \
|
||||
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,69 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Config parameters for Darknet based yolov3_darknet53 models."""
|
||||
|
||||
|
||||
class ConfigYOLOV3DarkNet53:
|
||||
"""
|
||||
Config parameters for the yolov3_darknet53.
|
||||
|
||||
Examples:
|
||||
ConfigYOLOV3DarkNet53()
|
||||
"""
|
||||
# train_param
|
||||
# data augmentation related
|
||||
hue = 0.1
|
||||
saturation = 1.5
|
||||
value = 1.5
|
||||
jitter = 0.3
|
||||
|
||||
resize_rate = 1
|
||||
multi_scale = [[320, 320],
|
||||
[352, 352],
|
||||
[384, 384],
|
||||
[416, 416],
|
||||
[448, 448],
|
||||
[480, 480],
|
||||
[512, 512],
|
||||
[544, 544],
|
||||
[576, 576],
|
||||
[608, 608]
|
||||
]
|
||||
|
||||
num_classes = 80
|
||||
max_box = 50
|
||||
|
||||
backbone_input_shape = [32, 64, 128, 256, 512]
|
||||
backbone_shape = [64, 128, 256, 512, 1024]
|
||||
backbone_layers = [1, 2, 8, 8, 4]
|
||||
|
||||
# confidence under ignore_threshold means no object when training
|
||||
ignore_threshold = 0.7
|
||||
|
||||
# h->w
|
||||
anchor_scales = [(10, 13),
|
||||
(16, 30),
|
||||
(33, 23),
|
||||
(30, 61),
|
||||
(62, 45),
|
||||
(59, 119),
|
||||
(116, 90),
|
||||
(156, 198),
|
||||
(373, 326)]
|
||||
out_channel = 255
|
||||
|
||||
quantization_aware = True
|
||||
# test_param
|
||||
test_img_shape = [416, 416]
|
@ -0,0 +1,208 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DarkNet model."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def conv_block(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
dilation=1):
|
||||
"""Get a conv2d batchnorm and relu layer"""
|
||||
pad_mode = 'same'
|
||||
padding = 0
|
||||
|
||||
return nn.Conv2dBnAct(in_channels, out_channels, kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
has_bn=True,
|
||||
momentum=0.1,
|
||||
activation='relu')
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
|
||||
"""
|
||||
DarkNet V1 residual block definition.
|
||||
|
||||
Args:
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
Examples:
|
||||
ResidualBlock(3, 208)
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels):
|
||||
|
||||
super(ResidualBlock, self).__init__()
|
||||
out_chls = out_channels//2
|
||||
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
|
||||
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
identity = x
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
out = self.add(out, identity)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class DarkNet(nn.Cell):
|
||||
"""
|
||||
DarkNet V1 network.
|
||||
|
||||
Args:
|
||||
block: Cell. Block for network.
|
||||
layer_nums: List. Numbers of different layers.
|
||||
in_channels: Integer. Input channel.
|
||||
out_channels: Integer. Output channel.
|
||||
detect: Bool. Whether detect or not. Default:False.
|
||||
|
||||
Returns:
|
||||
Tuple, tuple of output tensor,(f1,f2,f3,f4,f5).
|
||||
|
||||
Examples:
|
||||
DarkNet(ResidualBlock,
|
||||
[1, 2, 8, 8, 4],
|
||||
[32, 64, 128, 256, 512],
|
||||
[64, 128, 256, 512, 1024],
|
||||
100)
|
||||
"""
|
||||
def __init__(self,
|
||||
block,
|
||||
layer_nums,
|
||||
in_channels,
|
||||
out_channels,
|
||||
detect=False):
|
||||
super(DarkNet, self).__init__()
|
||||
|
||||
self.outchannel = out_channels[-1]
|
||||
self.detect = detect
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 5:
|
||||
raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!")
|
||||
self.conv0 = conv_block(3,
|
||||
in_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1)
|
||||
self.conv1 = conv_block(in_channels[0],
|
||||
out_channels[0],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
self.conv2 = conv_block(in_channels[1],
|
||||
out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
self.conv3 = conv_block(in_channels[2],
|
||||
out_channels[2],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
self.conv4 = conv_block(in_channels[3],
|
||||
out_channels[3],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
self.conv5 = conv_block(in_channels[4],
|
||||
out_channels[4],
|
||||
kernel_size=3,
|
||||
stride=2)
|
||||
|
||||
self.layer1 = self._make_layer(block,
|
||||
layer_nums[0],
|
||||
in_channel=out_channels[0],
|
||||
out_channel=out_channels[0])
|
||||
self.layer2 = self._make_layer(block,
|
||||
layer_nums[1],
|
||||
in_channel=out_channels[1],
|
||||
out_channel=out_channels[1])
|
||||
self.layer3 = self._make_layer(block,
|
||||
layer_nums[2],
|
||||
in_channel=out_channels[2],
|
||||
out_channel=out_channels[2])
|
||||
self.layer4 = self._make_layer(block,
|
||||
layer_nums[3],
|
||||
in_channel=out_channels[3],
|
||||
out_channel=out_channels[3])
|
||||
self.layer5 = self._make_layer(block,
|
||||
layer_nums[4],
|
||||
in_channel=out_channels[4],
|
||||
out_channel=out_channels[4])
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel):
|
||||
"""
|
||||
Make Layer for DarkNet.
|
||||
|
||||
:param block: Cell. DarkNet block.
|
||||
:param layer_num: Integer. Layer number.
|
||||
:param in_channel: Integer. Input channel.
|
||||
:param out_channel: Integer. Output channel.
|
||||
|
||||
Examples:
|
||||
_make_layer(ConvBlock, 1, 128, 256)
|
||||
"""
|
||||
layers = []
|
||||
darkblk = block(in_channel, out_channel)
|
||||
layers.append(darkblk)
|
||||
|
||||
for _ in range(1, layer_num):
|
||||
darkblk = block(out_channel, out_channel)
|
||||
layers.append(darkblk)
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
c1 = self.conv0(x)
|
||||
c2 = self.conv1(c1)
|
||||
c3 = self.layer1(c2)
|
||||
c4 = self.conv2(c3)
|
||||
c5 = self.layer2(c4)
|
||||
c6 = self.conv3(c5)
|
||||
c7 = self.layer3(c6)
|
||||
c8 = self.conv4(c7)
|
||||
c9 = self.layer4(c8)
|
||||
c10 = self.conv5(c9)
|
||||
c11 = self.layer5(c10)
|
||||
if self.detect:
|
||||
return c7, c9, c11
|
||||
|
||||
return c11
|
||||
|
||||
def get_out_channels(self):
|
||||
return self.outchannel
|
||||
|
||||
|
||||
def darknet53():
|
||||
"""
|
||||
Get DarkNet53 neural network.
|
||||
|
||||
Returns:
|
||||
Cell, cell instance of DarkNet53 neural network.
|
||||
|
||||
Examples:
|
||||
darknet53()
|
||||
"""
|
||||
return DarkNet(ResidualBlock, [1, 2, 8, 8, 4],
|
||||
[32, 64, 128, 256, 512],
|
||||
[64, 128, 256, 512, 1024])
|
@ -0,0 +1,60 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Yolo dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
"""Distributed sampler."""
|
||||
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
|
||||
if num_replicas is None:
|
||||
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||
num_replicas = 1
|
||||
if rank is None:
|
||||
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||
rank = 0
|
||||
self.dataset_size = dataset_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
# deterministically shuffle based on epoch
|
||||
if self.shuffle:
|
||||
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
|
||||
indices = indices.tolist()
|
||||
self.epoch += 1
|
||||
# change to list type
|
||||
else:
|
||||
indices = list(range(self.dataset_size))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
@ -0,0 +1,179 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parameter init."""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.common.initializer import Initializer as MeInitializer
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
np.random.seed(5)
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
r"""Return the recommended gain value for the given nonlinearity function.
|
||||
The values are as follows:
|
||||
|
||||
================= ====================================================
|
||||
nonlinearity gain
|
||||
================= ====================================================
|
||||
Linear / Identity :math:`1`
|
||||
Conv{1,2,3}D :math:`1`
|
||||
Sigmoid :math:`1`
|
||||
Tanh :math:`\frac{5}{3}`
|
||||
ReLU :math:`\sqrt{2}`
|
||||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
||||
================= ====================================================
|
||||
|
||||
Args:
|
||||
nonlinearity: the non-linear function (`nn.functional` name)
|
||||
param: optional parameter for the non-linear function
|
||||
|
||||
Examples:
|
||||
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
|
||||
"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
return 1
|
||||
if nonlinearity == 'tanh':
|
||||
return 5.0 / 3
|
||||
if nonlinearity == 'relu':
|
||||
return math.sqrt(2.0)
|
||||
if nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
# True/False are instances of int, hence check above
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
|
||||
def _assignment(arr, num):
|
||||
"""Assign the value of 'num' and 'arr'."""
|
||||
if arr.shape == ():
|
||||
arr = arr.reshape((1))
|
||||
arr[:] = num
|
||||
arr = arr.reshape(())
|
||||
else:
|
||||
if isinstance(num, np.ndarray):
|
||||
arr[:] = num[:]
|
||||
else:
|
||||
arr[:] = num
|
||||
return arr
|
||||
|
||||
|
||||
def _calculate_correct_fan(array, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
r"""Fills the input `Tensor` with values according to the method
|
||||
described in `Delving deep into rectifiers: Surpassing human-level
|
||||
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||
uniform distribution. The resulting tensor will have values sampled from
|
||||
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||
|
||||
.. math::
|
||||
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||
|
||||
Also known as He initialization.
|
||||
|
||||
Args:
|
||||
tensor: an n-dimensional `Tensor`
|
||||
a: the negative slope of the rectifier used after this layer (only
|
||||
used with ``'leaky_relu'``)
|
||||
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||
preserves the magnitude of the variance of the weights in the
|
||||
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||
backwards pass.
|
||||
nonlinearity: the non-linear function (`nn.functional` name),
|
||||
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||
|
||||
Examples:
|
||||
>>> w = np.empty(3, 5)
|
||||
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||
"""
|
||||
fan = _calculate_correct_fan(arr, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
return np.random.uniform(-bound, bound, arr.shape)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(arr):
|
||||
"""Calculate fan in and fan out."""
|
||||
dimensions = len(arr.shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
|
||||
|
||||
num_input_fmaps = arr.shape[1]
|
||||
num_output_fmaps = arr.shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = arr[0][0].size
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
class KaimingUniform(MeInitializer):
|
||||
"""Kaiming uniform initializer."""
|
||||
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
super(KaimingUniform, self).__init__()
|
||||
self.a = a
|
||||
self.mode = mode
|
||||
self.nonlinearity = nonlinearity
|
||||
|
||||
def _initialize(self, arr):
|
||||
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
|
||||
_assignment(arr, tmp)
|
||||
|
||||
|
||||
def default_recurisive_init(custom_cell):
|
||||
"""Initialize parameter."""
|
||||
for _, cell in custom_cell.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||
cell.bias.default_input.dtype)
|
||||
elif isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||
cell.weight.default_input.shape,
|
||||
cell.weight.default_input.dtype).to_tensor()
|
||||
if cell.bias is not None:
|
||||
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||
bound = 1 / math.sqrt(fan_in)
|
||||
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||
cell.bias.default_input.dtype)
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
pass
|
@ -0,0 +1,80 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
"""
|
||||
Logger.
|
||||
|
||||
Args:
|
||||
logger_name: String. Logger name.
|
||||
rank: Integer. Rank id.
|
||||
"""
|
||||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
self.rank = rank
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""Setup logging file."""
|
||||
self.rank = rank
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
"""Get Logger."""
|
||||
logger = LOGGER('yolov3_darknet53', rank)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
@ -0,0 +1,70 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""YOLOV3 loss."""
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class XYLoss(nn.Cell):
|
||||
"""Loss for x and y."""
|
||||
def __init__(self):
|
||||
super(XYLoss, self).__init__()
|
||||
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
|
||||
xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
|
||||
xy_loss = self.reduce_sum(xy_loss, ())
|
||||
return xy_loss
|
||||
|
||||
|
||||
class WHLoss(nn.Cell):
|
||||
"""Loss for w and h."""
|
||||
def __init__(self):
|
||||
super(WHLoss, self).__init__()
|
||||
self.square = P.Square()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
|
||||
wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh)
|
||||
wh_loss = self.reduce_sum(wh_loss, ())
|
||||
return wh_loss
|
||||
|
||||
|
||||
class ConfidenceLoss(nn.Cell):
|
||||
"""Loss for confidence."""
|
||||
def __init__(self):
|
||||
super(ConfidenceLoss, self).__init__()
|
||||
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, object_mask, predict_confidence, ignore_mask):
|
||||
confidence_loss = self.cross_entropy(predict_confidence, object_mask)
|
||||
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
|
||||
confidence_loss = self.reduce_sum(confidence_loss, ())
|
||||
return confidence_loss
|
||||
|
||||
|
||||
class ClassLoss(nn.Cell):
|
||||
"""Loss for classification."""
|
||||
def __init__(self):
|
||||
super(ClassLoss, self).__init__()
|
||||
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, object_mask, predict_class, class_probs):
|
||||
class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
|
||||
class_loss = self.reduce_sum(class_loss, ())
|
||||
return class_loss
|
@ -0,0 +1,143 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate scheduler."""
|
||||
import math
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
"""Linear learning rate."""
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
|
||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
|
||||
"""Warmup step learning rate."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
milestones = lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr_each_step = []
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma**milestones_steps_counter[i]
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
lr_epochs = []
|
||||
for i in range(1, max_epoch):
|
||||
if i % epoch_size == 0:
|
||||
lr_epochs.append(i)
|
||||
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
"""Cosine annealing learning rate."""
|
||||
base_lr = lr
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = 0
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
"""Cosine annealing learning rate V2."""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
last_lr = 0
|
||||
last_epoch_V1 = 0
|
||||
|
||||
T_max_V2 = int(max_epoch*1/3)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
if i < total_steps*2/3:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||
last_lr = lr
|
||||
last_epoch_V1 = last_epoch
|
||||
else:
|
||||
base_lr = last_lr
|
||||
last_epoch = last_epoch-last_epoch_V1
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2
|
||||
|
||||
lr_each_step.append(lr)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
"""Warmup cosine annealing learning rate."""
|
||||
start_sample_epoch = 60
|
||||
step_sample = 2
|
||||
tobe_sampled_epoch = 60
|
||||
end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch
|
||||
max_sampled_epoch = max_epoch+tobe_sampled_epoch
|
||||
T_max = max_sampled_epoch
|
||||
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
|
||||
for i in range(total_sampled_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
|
||||
continue
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
assert total_steps == len(lr_each_step)
|
||||
return np.array(lr_each_step).astype(np.float32)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,177 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
self.tb_writer = tb_writer
|
||||
self.cur_step = 1
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||
self.cur_step += 1
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
def load_backbone(net, ckpt_path, args):
|
||||
"""Load darknet53 backbone checkpoint."""
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
yolo_backbone_prefix = 'feature_map.backbone'
|
||||
darknet_backbone_prefix = 'network.backbone'
|
||||
find_param = []
|
||||
not_found_param = []
|
||||
|
||||
for name, cell in net.cells_and_names():
|
||||
if name.startswith(yolo_backbone_prefix):
|
||||
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
|
||||
if isinstance(cell, (nn.Conv2d, nn.Dense)):
|
||||
darknet_weight = '{}.weight'.format(name)
|
||||
darknet_bias = '{}.bias'.format(name)
|
||||
if darknet_weight in param_dict:
|
||||
cell.weight.default_input = param_dict[darknet_weight].data
|
||||
find_param.append(darknet_weight)
|
||||
else:
|
||||
not_found_param.append(darknet_weight)
|
||||
if darknet_bias in param_dict:
|
||||
cell.bias.default_input = param_dict[darknet_bias].data
|
||||
find_param.append(darknet_bias)
|
||||
else:
|
||||
not_found_param.append(darknet_bias)
|
||||
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
darknet_moving_mean = '{}.moving_mean'.format(name)
|
||||
darknet_moving_variance = '{}.moving_variance'.format(name)
|
||||
darknet_gamma = '{}.gamma'.format(name)
|
||||
darknet_beta = '{}.beta'.format(name)
|
||||
if darknet_moving_mean in param_dict:
|
||||
cell.moving_mean.default_input = param_dict[darknet_moving_mean].data
|
||||
find_param.append(darknet_moving_mean)
|
||||
else:
|
||||
not_found_param.append(darknet_moving_mean)
|
||||
if darknet_moving_variance in param_dict:
|
||||
cell.moving_variance.default_input = param_dict[darknet_moving_variance].data
|
||||
find_param.append(darknet_moving_variance)
|
||||
else:
|
||||
not_found_param.append(darknet_moving_variance)
|
||||
if darknet_gamma in param_dict:
|
||||
cell.gamma.default_input = param_dict[darknet_gamma].data
|
||||
find_param.append(darknet_gamma)
|
||||
else:
|
||||
not_found_param.append(darknet_gamma)
|
||||
if darknet_beta in param_dict:
|
||||
cell.beta.default_input = param_dict[darknet_beta].data
|
||||
find_param.append(darknet_beta)
|
||||
else:
|
||||
not_found_param.append(darknet_beta)
|
||||
|
||||
args.logger.info('================found_param {}========='.format(len(find_param)))
|
||||
args.logger.info(find_param)
|
||||
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
|
||||
args.logger.info(not_found_param)
|
||||
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def default_wd_filter(x):
|
||||
"""default weight decay filter."""
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# all bias not using weight decay
|
||||
return False
|
||||
if parameter_name.endswith('.gamma'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
return False
|
||||
if parameter_name.endswith('.beta'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_param_groups(network):
|
||||
"""Param groups for optimizer."""
|
||||
decay_params = []
|
||||
no_decay_params = []
|
||||
for x in network.trainable_params():
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# all bias not using weight decay
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.gamma'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.beta'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
no_decay_params.append(x)
|
||||
else:
|
||||
decay_params.append(x)
|
||||
|
||||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||
|
||||
|
||||
class ShapeRecord:
|
||||
"""Log image shape."""
|
||||
def __init__(self):
|
||||
self.shape_record = {
|
||||
320: 0,
|
||||
352: 0,
|
||||
384: 0,
|
||||
416: 0,
|
||||
448: 0,
|
||||
480: 0,
|
||||
512: 0,
|
||||
544: 0,
|
||||
576: 0,
|
||||
608: 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
def set(self, shape):
|
||||
if len(shape) > 1:
|
||||
shape = shape[0]
|
||||
shape = int(shape)
|
||||
self.shape_record[shape] += 1
|
||||
self.shape_record['total'] += 1
|
||||
|
||||
def show(self, logger):
|
||||
for key in self.shape_record:
|
||||
rate = self.shape_record[key] / float(self.shape_record['total'])
|
||||
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,184 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""YOLOV3 dataset."""
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
from pycocotools.coco import COCO
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.transforms.vision.c_transforms as CV
|
||||
|
||||
from src.distributed_sampler import DistributedSampler
|
||||
from src.transforms import reshape_fn, MultiScaleTrans
|
||||
|
||||
|
||||
min_keypoints_per_image = 10
|
||||
|
||||
|
||||
def _has_only_empty_bbox(anno):
|
||||
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
|
||||
|
||||
|
||||
def _count_visible_keypoints(anno):
|
||||
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
|
||||
|
||||
|
||||
def has_valid_annotation(anno):
|
||||
"""Check annotation file."""
|
||||
# if it's empty, there is no annotation
|
||||
if not anno:
|
||||
return False
|
||||
# if all boxes have close to zero area, there is no annotation
|
||||
if _has_only_empty_bbox(anno):
|
||||
return False
|
||||
# keypoints task have a slight different critera for considering
|
||||
# if an annotation is valid
|
||||
if "keypoints" not in anno[0]:
|
||||
return True
|
||||
# for keypoint detection tasks, only consider valid images those
|
||||
# containing at least min_keypoints_per_image
|
||||
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class COCOYoloDataset:
|
||||
"""YOLOV3 Dataset for COCO."""
|
||||
def __init__(self, root, ann_file, remove_images_without_annotations=True,
|
||||
filter_crowd_anno=True, is_training=True):
|
||||
self.coco = COCO(ann_file)
|
||||
self.root = root
|
||||
self.img_ids = list(sorted(self.coco.imgs.keys()))
|
||||
self.filter_crowd_anno = filter_crowd_anno
|
||||
self.is_training = is_training
|
||||
|
||||
# filter images without any annotations
|
||||
if remove_images_without_annotations:
|
||||
img_ids = []
|
||||
for img_id in self.img_ids:
|
||||
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||
anno = self.coco.loadAnns(ann_ids)
|
||||
if has_valid_annotation(anno):
|
||||
img_ids.append(img_id)
|
||||
self.img_ids = img_ids
|
||||
|
||||
self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
|
||||
|
||||
self.cat_ids_to_continuous_ids = {
|
||||
v: i for i, v in enumerate(self.coco.getCatIds())
|
||||
}
|
||||
self.continuous_ids_cat_ids = {
|
||||
v: k for k, v in self.cat_ids_to_continuous_ids.items()
|
||||
}
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
|
||||
Returns:
|
||||
(img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints",
|
||||
generated by the image's annotation. img is a PIL image.
|
||||
"""
|
||||
coco = self.coco
|
||||
img_id = self.img_ids[index]
|
||||
img_path = coco.loadImgs(img_id)[0]["file_name"]
|
||||
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
|
||||
if not self.is_training:
|
||||
return img, img_id
|
||||
|
||||
ann_ids = coco.getAnnIds(imgIds=img_id)
|
||||
target = coco.loadAnns(ann_ids)
|
||||
# filter crowd annotations
|
||||
if self.filter_crowd_anno:
|
||||
annos = [anno for anno in target if anno["iscrowd"] == 0]
|
||||
else:
|
||||
annos = [anno for anno in target]
|
||||
|
||||
target = {}
|
||||
boxes = [anno["bbox"] for anno in annos]
|
||||
target["bboxes"] = boxes
|
||||
|
||||
classes = [anno["category_id"] for anno in annos]
|
||||
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
|
||||
target["labels"] = classes
|
||||
|
||||
bboxes = target['bboxes']
|
||||
labels = target['labels']
|
||||
out_target = []
|
||||
for bbox, label in zip(bboxes, labels):
|
||||
tmp = []
|
||||
# convert to [x_min y_min x_max y_max]
|
||||
bbox = self._convetTopDown(bbox)
|
||||
tmp.extend(bbox)
|
||||
tmp.append(int(label))
|
||||
# tmp [x_min y_min x_max y_max, label]
|
||||
out_target.append(tmp)
|
||||
return img, out_target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_ids)
|
||||
|
||||
def _convetTopDown(self, bbox):
|
||||
x_min = bbox[0]
|
||||
y_min = bbox[1]
|
||||
w = bbox[2]
|
||||
h = bbox[3]
|
||||
return [x_min, y_min, x_min+w, y_min+h]
|
||||
|
||||
|
||||
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
|
||||
config=None, is_training=True, shuffle=True):
|
||||
"""Create dataset for YOLOV3."""
|
||||
if is_training:
|
||||
filter_crowd = True
|
||||
remove_empty_anno = True
|
||||
else:
|
||||
filter_crowd = False
|
||||
remove_empty_anno = False
|
||||
|
||||
yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
|
||||
remove_images_without_annotations=remove_empty_anno, is_training=is_training)
|
||||
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
|
||||
hwc_to_chw = CV.HWC2CHW()
|
||||
|
||||
config.dataset_size = len(yolo_dataset)
|
||||
num_parallel_workers1 = int(64 / device_num)
|
||||
num_parallel_workers2 = int(16 / device_num)
|
||||
if is_training:
|
||||
multi_scale_trans = MultiScaleTrans(config, device_num)
|
||||
if device_num != 8:
|
||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
|
||||
num_parallel_workers=num_parallel_workers1,
|
||||
sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
||||
num_parallel_workers=num_parallel_workers2, drop_remainder=True)
|
||||
else:
|
||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
|
||||
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
||||
num_parallel_workers=8, drop_remainder=True)
|
||||
else:
|
||||
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
|
||||
sampler=distributed_sampler)
|
||||
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
|
||||
ds = ds.map(input_columns=["image", "img_id"],
|
||||
output_columns=["image", "image_shape", "img_id"],
|
||||
columns_order=["image", "image_shape", "img_id"],
|
||||
operations=compose_map_func, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
ds = ds.repeat(max_epoch)
|
||||
|
||||
return ds, len(yolo_dataset)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue