!2726 Add YOLOV3-DarkNet53 to Model Zoo
Merge pull request !2726 from yangyongjie/masterpull/2726/MERGE
commit
cf731d36a0
@ -0,0 +1,132 @@
|
|||||||
|
# YOLOV3-DarkNet53 Example
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This is an example of training YOLOV3-DarkNet53 with COCO2014 dataset in MindSpore.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||||
|
|
||||||
|
- Download the dataset COCO2014.
|
||||||
|
|
||||||
|
> Unzip the COCO2014 dataset to any path you want, the folder should include train and eval dataset as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
└─dataset
|
||||||
|
├─train2014
|
||||||
|
├─val2014
|
||||||
|
└─annotations
|
||||||
|
```
|
||||||
|
|
||||||
|
## Structure
|
||||||
|
|
||||||
|
```shell
|
||||||
|
.
|
||||||
|
└─yolov3_darknet53
|
||||||
|
├─README.md
|
||||||
|
├─scripts
|
||||||
|
├─run_standalone_train.sh # launch standalone training(1p)
|
||||||
|
├─run_distribute_train.sh # launch distributed training(8p)
|
||||||
|
└─run_eval.sh # launch evaluating
|
||||||
|
├─src
|
||||||
|
├─config.py # parameter configuration
|
||||||
|
├─darknet.py # backbone of network
|
||||||
|
├─distributed_sampler.py # iterator of dataset
|
||||||
|
├─initializer.py # initializer of parameters
|
||||||
|
├─logger.py # log function
|
||||||
|
├─loss.py # loss function
|
||||||
|
├─lr_scheduler.py # generate learning rate
|
||||||
|
├─transforms.py # Preprocess data
|
||||||
|
├─util.py # util function
|
||||||
|
├─yolo.py # yolov3 network
|
||||||
|
├─yolo_dataset.py # create dataset for YOLOV3
|
||||||
|
├─eval.py # eval net
|
||||||
|
└─train.py # train net
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
### Train
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
# distributed training
|
||||||
|
sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [MINDSPORE_HCCL_CONFIG_PATH]
|
||||||
|
|
||||||
|
# standalone training
|
||||||
|
sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# distributed training example(8p)
|
||||||
|
sh run_distribute_train.sh dataset/coco2014 backbone/backbone.ckpt rank_table_8p.json
|
||||||
|
|
||||||
|
# standalone training example(1p)
|
||||||
|
sh run_standalone_train.sh dataset/coco2014 backbone/backbone.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||||
|
|
||||||
|
#### Result
|
||||||
|
|
||||||
|
Training result will be stored in the scripts path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in log.txt.
|
||||||
|
|
||||||
|
```
|
||||||
|
# distribute training result(8p)
|
||||||
|
epoch[0], iter[0], loss:14623.384766, 1.23 imgs/sec, lr:7.812499825377017e-05
|
||||||
|
epoch[0], iter[100], loss:1486.253051, 15.01 imgs/sec, lr:0.007890624925494194
|
||||||
|
epoch[0], iter[200], loss:288.579535, 490.41 imgs/sec, lr:0.015703124925494194
|
||||||
|
epoch[0], iter[300], loss:153.136754, 531.99 imgs/sec, lr:0.023515624925494194
|
||||||
|
epoch[1], iter[400], loss:106.429322, 405.14 imgs/sec, lr:0.03132812678813934
|
||||||
|
...
|
||||||
|
epoch[318], iter[102000], loss:34.135306, 431.06 imgs/sec, lr:9.63797629083274e-06
|
||||||
|
epoch[319], iter[102100], loss:35.652469, 449.52 imgs/sec, lr:2.409552052995423e-06
|
||||||
|
epoch[319], iter[102200], loss:34.652273, 384.02 imgs/sec, lr:2.409552052995423e-06
|
||||||
|
epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e-06
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Infer
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
# infer
|
||||||
|
sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# infer with checkpoint
|
||||||
|
sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
> checkpoint can be produced in training process.
|
||||||
|
|
||||||
|
|
||||||
|
#### Result
|
||||||
|
|
||||||
|
Inference result will be stored in the scripts path, whose folder name is "eval". Under this, you can find result like the followings in log.txt.
|
||||||
|
|
||||||
|
```
|
||||||
|
=============coco eval reulst=========
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311
|
||||||
|
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.528
|
||||||
|
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.322
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.127
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.323
|
||||||
|
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.428
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.259
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.398
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.423
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.224
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.442
|
||||||
|
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551
|
||||||
|
```
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,81 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 3 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [MINDSPORE_HCCL_CONFIG_PATH]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||||
|
MINDSPORE_HCCL_CONFIG_PATH=$(get_real_path $3)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
echo $MINDSPORE_HCCL_CONFIG_PATH
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $MINDSPORE_HCCL_CONFIG_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=8
|
||||||
|
export RANK_SIZE=8
|
||||||
|
export MINDSPORE_HCCL_CONFIG_PATH=$MINDSPORE_HCCL_CONFIG_PATH
|
||||||
|
|
||||||
|
for((i=0; i<${DEVICE_NUM}; i++))
|
||||||
|
do
|
||||||
|
export DEVICE_ID=$i
|
||||||
|
export RANK_ID=$i
|
||||||
|
rm -rf ./train_parallel$i
|
||||||
|
mkdir ./train_parallel$i
|
||||||
|
cp ../*.py ./train_parallel$i
|
||||||
|
cp -r ../src ./train_parallel$i
|
||||||
|
cd ./train_parallel$i || exit
|
||||||
|
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
python train.py \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||||
|
--is_distributed=1 \
|
||||||
|
--lr=0.1 \
|
||||||
|
--T_max=320 \
|
||||||
|
--max_epoch=320 \
|
||||||
|
--warmup_epochs=4 \
|
||||||
|
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||||
|
cd ..
|
||||||
|
done
|
@ -0,0 +1,66 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
CHECKPOINT_PATH=$(get_real_path $2)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
echo $CHECKPOINT_PATH
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $CHECKPOINT_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_SIZE=$DEVICE_NUM
|
||||||
|
export RANK_ID=0
|
||||||
|
|
||||||
|
if [ -d "eval" ];
|
||||||
|
then
|
||||||
|
rm -rf ./eval
|
||||||
|
fi
|
||||||
|
mkdir ./eval
|
||||||
|
cp ../*.py ./eval
|
||||||
|
cp -r ../src ./eval
|
||||||
|
cd ./eval || exit
|
||||||
|
env > env.log
|
||||||
|
echo "start infering for device $DEVICE_ID"
|
||||||
|
python eval.py \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained=$CHECKPOINT_PATH \
|
||||||
|
--testing_shape=416 > log.txt 2>&1 &
|
||||||
|
cd ..
|
@ -0,0 +1,73 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if [ $# != 2 ]
|
||||||
|
then
|
||||||
|
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
get_real_path(){
|
||||||
|
if [ "${1:0:1}" == "/" ]; then
|
||||||
|
echo "$1"
|
||||||
|
else
|
||||||
|
echo "$(realpath -m $PWD/$1)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DATASET_PATH=$(get_real_path $1)
|
||||||
|
echo $DATASET_PATH
|
||||||
|
PRETRAINED_BACKBONE=$(get_real_path $2)
|
||||||
|
echo $PRETRAINED_BACKBONE
|
||||||
|
|
||||||
|
if [ ! -d $DATASET_PATH ]
|
||||||
|
then
|
||||||
|
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f $PRETRAINED_BACKBONE ]
|
||||||
|
then
|
||||||
|
echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
export DEVICE_NUM=1
|
||||||
|
export DEVICE_ID=0
|
||||||
|
export RANK_ID=0
|
||||||
|
export RANK_SIZE=1
|
||||||
|
|
||||||
|
if [ -d "train" ];
|
||||||
|
then
|
||||||
|
rm -rf ./train
|
||||||
|
fi
|
||||||
|
mkdir ./train
|
||||||
|
cp ../*.py ./train
|
||||||
|
cp -r ../src ./train
|
||||||
|
cd ./train || exit
|
||||||
|
echo "start training for device $DEVICE_ID"
|
||||||
|
env > env.log
|
||||||
|
|
||||||
|
python train.py \
|
||||||
|
--data_dir=$DATASET_PATH \
|
||||||
|
--pretrained_backbone=$PRETRAINED_BACKBONE \
|
||||||
|
--is_distributed=0 \
|
||||||
|
--lr=0.1 \
|
||||||
|
--T_max=320 \
|
||||||
|
--max_epoch=320 \
|
||||||
|
--warmup_epochs=4 \
|
||||||
|
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
|
||||||
|
cd ..
|
@ -0,0 +1,68 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Config parameters for Darknet based yolov3_darknet53 models."""
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigYOLOV3DarkNet53:
|
||||||
|
"""
|
||||||
|
Config parameters for the yolov3_darknet53.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
ConfigYOLOV3DarkNet53()
|
||||||
|
"""
|
||||||
|
# train_param
|
||||||
|
# data augmentation related
|
||||||
|
hue = 0.1
|
||||||
|
saturation = 1.5
|
||||||
|
value = 1.5
|
||||||
|
jitter = 0.3
|
||||||
|
|
||||||
|
resize_rate = 1
|
||||||
|
multi_scale = [[320, 320],
|
||||||
|
[352, 352],
|
||||||
|
[384, 384],
|
||||||
|
[416, 416],
|
||||||
|
[448, 448],
|
||||||
|
[480, 480],
|
||||||
|
[512, 512],
|
||||||
|
[544, 544],
|
||||||
|
[576, 576],
|
||||||
|
[608, 608]
|
||||||
|
]
|
||||||
|
|
||||||
|
num_classes = 80
|
||||||
|
max_box = 50
|
||||||
|
|
||||||
|
backbone_input_shape = [32, 64, 128, 256, 512]
|
||||||
|
backbone_shape = [64, 128, 256, 512, 1024]
|
||||||
|
backbone_layers = [1, 2, 8, 8, 4]
|
||||||
|
|
||||||
|
# confidence under ignore_threshold means no object when training
|
||||||
|
ignore_threshold = 0.7
|
||||||
|
|
||||||
|
# h->w
|
||||||
|
anchor_scales = [(10, 13),
|
||||||
|
(16, 30),
|
||||||
|
(33, 23),
|
||||||
|
(30, 61),
|
||||||
|
(62, 45),
|
||||||
|
(59, 119),
|
||||||
|
(116, 90),
|
||||||
|
(156, 198),
|
||||||
|
(373, 326)]
|
||||||
|
out_channel = 255
|
||||||
|
|
||||||
|
# test_param
|
||||||
|
test_img_shape = [416, 416]
|
@ -0,0 +1,211 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""DarkNet model."""
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
def conv_block(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
dilation=1):
|
||||||
|
"""Get a conv2d batchnorm and relu layer"""
|
||||||
|
pad_mode = 'same'
|
||||||
|
padding = 0
|
||||||
|
|
||||||
|
return nn.SequentialCell(
|
||||||
|
[nn.Conv2d(in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
pad_mode=pad_mode),
|
||||||
|
nn.BatchNorm2d(out_channels, momentum=0.1),
|
||||||
|
nn.ReLU()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Cell):
|
||||||
|
"""
|
||||||
|
DarkNet V1 residual block definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Integer. Input channel.
|
||||||
|
out_channels: Integer. Output channel.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, output tensor.
|
||||||
|
Examples:
|
||||||
|
ResidualBlock(3, 208)
|
||||||
|
"""
|
||||||
|
expansion = 4
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels):
|
||||||
|
|
||||||
|
super(ResidualBlock, self).__init__()
|
||||||
|
out_chls = out_channels//2
|
||||||
|
self.conv1 = conv_block(in_channels, out_chls, kernel_size=1, stride=1)
|
||||||
|
self.conv2 = conv_block(out_chls, out_channels, kernel_size=3, stride=1)
|
||||||
|
self.add = P.TensorAdd()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
identity = x
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.add(out, identity)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DarkNet(nn.Cell):
|
||||||
|
"""
|
||||||
|
DarkNet V1 network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block: Cell. Block for network.
|
||||||
|
layer_nums: List. Numbers of different layers.
|
||||||
|
in_channels: Integer. Input channel.
|
||||||
|
out_channels: Integer. Output channel.
|
||||||
|
detect: Bool. Whether detect or not. Default:False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple, tuple of output tensor,(f1,f2,f3,f4,f5).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
DarkNet(ResidualBlock,
|
||||||
|
[1, 2, 8, 8, 4],
|
||||||
|
[32, 64, 128, 256, 512],
|
||||||
|
[64, 128, 256, 512, 1024],
|
||||||
|
100)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
block,
|
||||||
|
layer_nums,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
detect=False):
|
||||||
|
super(DarkNet, self).__init__()
|
||||||
|
|
||||||
|
self.outchannel = out_channels[-1]
|
||||||
|
self.detect = detect
|
||||||
|
|
||||||
|
if not len(layer_nums) == len(in_channels) == len(out_channels) == 5:
|
||||||
|
raise ValueError("the length of layer_num, inchannel, outchannel list must be 5!")
|
||||||
|
self.conv0 = conv_block(3,
|
||||||
|
in_channels[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1)
|
||||||
|
self.conv1 = conv_block(in_channels[0],
|
||||||
|
out_channels[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2)
|
||||||
|
self.conv2 = conv_block(in_channels[1],
|
||||||
|
out_channels[1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2)
|
||||||
|
self.conv3 = conv_block(in_channels[2],
|
||||||
|
out_channels[2],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2)
|
||||||
|
self.conv4 = conv_block(in_channels[3],
|
||||||
|
out_channels[3],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2)
|
||||||
|
self.conv5 = conv_block(in_channels[4],
|
||||||
|
out_channels[4],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2)
|
||||||
|
|
||||||
|
self.layer1 = self._make_layer(block,
|
||||||
|
layer_nums[0],
|
||||||
|
in_channel=out_channels[0],
|
||||||
|
out_channel=out_channels[0])
|
||||||
|
self.layer2 = self._make_layer(block,
|
||||||
|
layer_nums[1],
|
||||||
|
in_channel=out_channels[1],
|
||||||
|
out_channel=out_channels[1])
|
||||||
|
self.layer3 = self._make_layer(block,
|
||||||
|
layer_nums[2],
|
||||||
|
in_channel=out_channels[2],
|
||||||
|
out_channel=out_channels[2])
|
||||||
|
self.layer4 = self._make_layer(block,
|
||||||
|
layer_nums[3],
|
||||||
|
in_channel=out_channels[3],
|
||||||
|
out_channel=out_channels[3])
|
||||||
|
self.layer5 = self._make_layer(block,
|
||||||
|
layer_nums[4],
|
||||||
|
in_channel=out_channels[4],
|
||||||
|
out_channel=out_channels[4])
|
||||||
|
|
||||||
|
def _make_layer(self, block, layer_num, in_channel, out_channel):
|
||||||
|
"""
|
||||||
|
Make Layer for DarkNet.
|
||||||
|
|
||||||
|
:param block: Cell. DarkNet block.
|
||||||
|
:param layer_num: Integer. Layer number.
|
||||||
|
:param in_channel: Integer. Input channel.
|
||||||
|
:param out_channel: Integer. Output channel.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
_make_layer(ConvBlock, 1, 128, 256)
|
||||||
|
"""
|
||||||
|
layers = []
|
||||||
|
darkblk = block(in_channel, out_channel)
|
||||||
|
layers.append(darkblk)
|
||||||
|
|
||||||
|
for _ in range(1, layer_num):
|
||||||
|
darkblk = block(out_channel, out_channel)
|
||||||
|
layers.append(darkblk)
|
||||||
|
|
||||||
|
return nn.SequentialCell(layers)
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
c1 = self.conv0(x)
|
||||||
|
c2 = self.conv1(c1)
|
||||||
|
c3 = self.layer1(c2)
|
||||||
|
c4 = self.conv2(c3)
|
||||||
|
c5 = self.layer2(c4)
|
||||||
|
c6 = self.conv3(c5)
|
||||||
|
c7 = self.layer3(c6)
|
||||||
|
c8 = self.conv4(c7)
|
||||||
|
c9 = self.layer4(c8)
|
||||||
|
c10 = self.conv5(c9)
|
||||||
|
c11 = self.layer5(c10)
|
||||||
|
if self.detect:
|
||||||
|
return c7, c9, c11
|
||||||
|
|
||||||
|
return c11
|
||||||
|
|
||||||
|
def get_out_channels(self):
|
||||||
|
return self.outchannel
|
||||||
|
|
||||||
|
|
||||||
|
def darknet53():
|
||||||
|
"""
|
||||||
|
Get DarkNet53 neural network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cell, cell instance of DarkNet53 neural network.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
darknet53()
|
||||||
|
"""
|
||||||
|
return DarkNet(ResidualBlock, [1, 2, 8, 8, 4],
|
||||||
|
[32, 64, 128, 256, 512],
|
||||||
|
[64, 128, 256, 512, 1024])
|
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Yolo dataset distributed sampler."""
|
||||||
|
from __future__ import division
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class DistributedSampler:
|
||||||
|
"""Distributed sampler."""
|
||||||
|
def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
|
||||||
|
if num_replicas is None:
|
||||||
|
print("***********Setting world_size to 1 since it is not passed in ******************")
|
||||||
|
num_replicas = 1
|
||||||
|
if rank is None:
|
||||||
|
print("***********Setting rank to 0 since it is not passed in ******************")
|
||||||
|
rank = 0
|
||||||
|
self.dataset_size = dataset_size
|
||||||
|
self.num_replicas = num_replicas
|
||||||
|
self.rank = rank
|
||||||
|
self.epoch = 0
|
||||||
|
self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
|
||||||
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# deterministically shuffle based on epoch
|
||||||
|
if self.shuffle:
|
||||||
|
indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
|
||||||
|
# np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
|
||||||
|
indices = indices.tolist()
|
||||||
|
self.epoch += 1
|
||||||
|
# change to list type
|
||||||
|
else:
|
||||||
|
indices = list(range(self.dataset_size))
|
||||||
|
|
||||||
|
# add extra samples to make it evenly divisible
|
||||||
|
indices += indices[:(self.total_size - len(indices))]
|
||||||
|
assert len(indices) == self.total_size
|
||||||
|
|
||||||
|
# subsample
|
||||||
|
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||||
|
assert len(indices) == self.num_samples
|
||||||
|
|
||||||
|
return iter(indices)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_samples
|
@ -0,0 +1,179 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Parameter init."""
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from mindspore.common import initializer as init
|
||||||
|
from mindspore.common.initializer import Initializer as MeInitializer
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
np.random.seed(5)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_gain(nonlinearity, param=None):
|
||||||
|
r"""Return the recommended gain value for the given nonlinearity function.
|
||||||
|
The values are as follows:
|
||||||
|
|
||||||
|
================= ====================================================
|
||||||
|
nonlinearity gain
|
||||||
|
================= ====================================================
|
||||||
|
Linear / Identity :math:`1`
|
||||||
|
Conv{1,2,3}D :math:`1`
|
||||||
|
Sigmoid :math:`1`
|
||||||
|
Tanh :math:`\frac{5}{3}`
|
||||||
|
ReLU :math:`\sqrt{2}`
|
||||||
|
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
|
||||||
|
================= ====================================================
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nonlinearity: the non-linear function (`nn.functional` name)
|
||||||
|
param: optional parameter for the non-linear function
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
|
||||||
|
"""
|
||||||
|
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||||
|
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||||
|
return 1
|
||||||
|
if nonlinearity == 'tanh':
|
||||||
|
return 5.0 / 3
|
||||||
|
if nonlinearity == 'relu':
|
||||||
|
return math.sqrt(2.0)
|
||||||
|
if nonlinearity == 'leaky_relu':
|
||||||
|
if param is None:
|
||||||
|
negative_slope = 0.01
|
||||||
|
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||||
|
# True/False are instances of int, hence check above
|
||||||
|
negative_slope = param
|
||||||
|
else:
|
||||||
|
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||||
|
return math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||||
|
|
||||||
|
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||||
|
|
||||||
|
|
||||||
|
def _assignment(arr, num):
|
||||||
|
"""Assign the value of 'num' and 'arr'."""
|
||||||
|
if arr.shape == ():
|
||||||
|
arr = arr.reshape((1))
|
||||||
|
arr[:] = num
|
||||||
|
arr = arr.reshape(())
|
||||||
|
else:
|
||||||
|
if isinstance(num, np.ndarray):
|
||||||
|
arr[:] = num[:]
|
||||||
|
else:
|
||||||
|
arr[:] = num
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_correct_fan(array, mode):
|
||||||
|
mode = mode.lower()
|
||||||
|
valid_modes = ['fan_in', 'fan_out']
|
||||||
|
if mode not in valid_modes:
|
||||||
|
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||||
|
|
||||||
|
fan_in, fan_out = _calculate_fan_in_and_fan_out(array)
|
||||||
|
return fan_in if mode == 'fan_in' else fan_out
|
||||||
|
|
||||||
|
|
||||||
|
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||||
|
r"""Fills the input `Tensor` with values according to the method
|
||||||
|
described in `Delving deep into rectifiers: Surpassing human-level
|
||||||
|
performance on ImageNet classification` - He, K. et al. (2015), using a
|
||||||
|
uniform distribution. The resulting tensor will have values sampled from
|
||||||
|
:math:`\mathcal{U}(-\text{bound}, \text{bound})` where
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
|
||||||
|
|
||||||
|
Also known as He initialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: an n-dimensional `Tensor`
|
||||||
|
a: the negative slope of the rectifier used after this layer (only
|
||||||
|
used with ``'leaky_relu'``)
|
||||||
|
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
|
||||||
|
preserves the magnitude of the variance of the weights in the
|
||||||
|
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
|
||||||
|
backwards pass.
|
||||||
|
nonlinearity: the non-linear function (`nn.functional` name),
|
||||||
|
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> w = np.empty(3, 5)
|
||||||
|
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
|
||||||
|
"""
|
||||||
|
fan = _calculate_correct_fan(arr, mode)
|
||||||
|
gain = calculate_gain(nonlinearity, a)
|
||||||
|
std = gain / math.sqrt(fan)
|
||||||
|
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||||
|
return np.random.uniform(-bound, bound, arr.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_fan_in_and_fan_out(arr):
|
||||||
|
"""Calculate fan in and fan out."""
|
||||||
|
dimensions = len(arr.shape)
|
||||||
|
if dimensions < 2:
|
||||||
|
raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")
|
||||||
|
|
||||||
|
num_input_fmaps = arr.shape[1]
|
||||||
|
num_output_fmaps = arr.shape[0]
|
||||||
|
receptive_field_size = 1
|
||||||
|
if dimensions > 2:
|
||||||
|
receptive_field_size = arr[0][0].size
|
||||||
|
fan_in = num_input_fmaps * receptive_field_size
|
||||||
|
fan_out = num_output_fmaps * receptive_field_size
|
||||||
|
|
||||||
|
return fan_in, fan_out
|
||||||
|
|
||||||
|
|
||||||
|
class KaimingUniform(MeInitializer):
|
||||||
|
"""Kaiming uniform initializer."""
|
||||||
|
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||||
|
super(KaimingUniform, self).__init__()
|
||||||
|
self.a = a
|
||||||
|
self.mode = mode
|
||||||
|
self.nonlinearity = nonlinearity
|
||||||
|
|
||||||
|
def _initialize(self, arr):
|
||||||
|
tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
|
||||||
|
_assignment(arr, tmp)
|
||||||
|
|
||||||
|
|
||||||
|
def default_recurisive_init(custom_cell):
|
||||||
|
"""Initialize parameter."""
|
||||||
|
for _, cell in custom_cell.cells_and_names():
|
||||||
|
if isinstance(cell, nn.Conv2d):
|
||||||
|
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||||
|
cell.weight.default_input.shape,
|
||||||
|
cell.weight.default_input.dtype).to_tensor()
|
||||||
|
if cell.bias is not None:
|
||||||
|
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||||
|
cell.bias.default_input.dtype)
|
||||||
|
elif isinstance(cell, nn.Dense):
|
||||||
|
cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
|
||||||
|
cell.weight.default_input.shape,
|
||||||
|
cell.weight.default_input.dtype).to_tensor()
|
||||||
|
if cell.bias is not None:
|
||||||
|
fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.default_input.asnumpy())
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
|
||||||
|
cell.bias.default_input.dtype)
|
||||||
|
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||||
|
pass
|
@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Custom Logger."""
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class LOGGER(logging.Logger):
|
||||||
|
"""
|
||||||
|
Logger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger_name: String. Logger name.
|
||||||
|
rank: Integer. Rank id.
|
||||||
|
"""
|
||||||
|
def __init__(self, logger_name, rank=0):
|
||||||
|
super(LOGGER, self).__init__(logger_name)
|
||||||
|
self.rank = rank
|
||||||
|
if rank % 8 == 0:
|
||||||
|
console = logging.StreamHandler(sys.stdout)
|
||||||
|
console.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
console.setFormatter(formatter)
|
||||||
|
self.addHandler(console)
|
||||||
|
|
||||||
|
def setup_logging_file(self, log_dir, rank=0):
|
||||||
|
"""Setup logging file."""
|
||||||
|
self.rank = rank
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||||
|
self.log_fn = os.path.join(log_dir, log_name)
|
||||||
|
fh = logging.FileHandler(self.log_fn)
|
||||||
|
fh.setLevel(logging.INFO)
|
||||||
|
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||||
|
fh.setFormatter(formatter)
|
||||||
|
self.addHandler(fh)
|
||||||
|
|
||||||
|
def info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO):
|
||||||
|
self._log(logging.INFO, msg, args, **kwargs)
|
||||||
|
|
||||||
|
def save_args(self, args):
|
||||||
|
self.info('Args:')
|
||||||
|
args_dict = vars(args)
|
||||||
|
for key in args_dict.keys():
|
||||||
|
self.info('--> %s: %s', key, args_dict[key])
|
||||||
|
self.info('')
|
||||||
|
|
||||||
|
def important_info(self, msg, *args, **kwargs):
|
||||||
|
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||||
|
line_width = 2
|
||||||
|
important_msg = '\n'
|
||||||
|
important_msg += ('*'*70 + '\n')*line_width
|
||||||
|
important_msg += ('*'*line_width + '\n')*2
|
||||||
|
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||||
|
important_msg += ('*'*line_width + '\n')*2
|
||||||
|
important_msg += ('*'*70 + '\n')*line_width
|
||||||
|
self.info(important_msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(path, rank):
|
||||||
|
"""Get Logger."""
|
||||||
|
logger = LOGGER('yolov3_darknet53', rank)
|
||||||
|
logger.setup_logging_file(path, rank)
|
||||||
|
return logger
|
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""YOLOV3 loss."""
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class XYLoss(nn.Cell):
|
||||||
|
"""Loss for x and y."""
|
||||||
|
def __init__(self):
|
||||||
|
super(XYLoss, self).__init__()
|
||||||
|
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
|
||||||
|
def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
|
||||||
|
xy_loss = object_mask * box_loss_scale * self.cross_entropy(predict_xy, true_xy)
|
||||||
|
xy_loss = self.reduce_sum(xy_loss, ())
|
||||||
|
return xy_loss
|
||||||
|
|
||||||
|
|
||||||
|
class WHLoss(nn.Cell):
|
||||||
|
"""Loss for w and h."""
|
||||||
|
def __init__(self):
|
||||||
|
super(WHLoss, self).__init__()
|
||||||
|
self.square = P.Square()
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
|
||||||
|
def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
|
||||||
|
wh_loss = object_mask * box_loss_scale * 0.5 * P.Square()(true_wh - predict_wh)
|
||||||
|
wh_loss = self.reduce_sum(wh_loss, ())
|
||||||
|
return wh_loss
|
||||||
|
|
||||||
|
|
||||||
|
class ConfidenceLoss(nn.Cell):
|
||||||
|
"""Loss for confidence."""
|
||||||
|
def __init__(self):
|
||||||
|
super(ConfidenceLoss, self).__init__()
|
||||||
|
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
|
||||||
|
def construct(self, object_mask, predict_confidence, ignore_mask):
|
||||||
|
confidence_loss = self.cross_entropy(predict_confidence, object_mask)
|
||||||
|
confidence_loss = object_mask * confidence_loss + (1 - object_mask) * confidence_loss * ignore_mask
|
||||||
|
confidence_loss = self.reduce_sum(confidence_loss, ())
|
||||||
|
return confidence_loss
|
||||||
|
|
||||||
|
|
||||||
|
class ClassLoss(nn.Cell):
|
||||||
|
"""Loss for classification."""
|
||||||
|
def __init__(self):
|
||||||
|
super(ClassLoss, self).__init__()
|
||||||
|
self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
|
||||||
|
def construct(self, object_mask, predict_class, class_probs):
|
||||||
|
class_loss = object_mask * self.cross_entropy(predict_class, class_probs)
|
||||||
|
class_loss = self.reduce_sum(class_loss, ())
|
||||||
|
return class_loss
|
@ -0,0 +1,144 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Learning rate scheduler."""
|
||||||
|
import math
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||||
|
"""Linear learning rate."""
|
||||||
|
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||||
|
lr = float(init_lr) + lr_inc * current_step
|
||||||
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
|
||||||
|
"""Warmup step learning rate."""
|
||||||
|
base_lr = lr
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(max_epoch * steps_per_epoch)
|
||||||
|
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||||
|
milestones = lr_epochs
|
||||||
|
milestones_steps = []
|
||||||
|
for milestone in milestones:
|
||||||
|
milestones_step = milestone * steps_per_epoch
|
||||||
|
milestones_steps.append(milestones_step)
|
||||||
|
|
||||||
|
lr_each_step = []
|
||||||
|
lr = base_lr
|
||||||
|
milestones_steps_counter = Counter(milestones_steps)
|
||||||
|
for i in range(total_steps):
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
lr = lr * gamma**milestones_steps_counter[i]
|
||||||
|
lr_each_step.append(lr)
|
||||||
|
|
||||||
|
return np.array(lr_each_step).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||||
|
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
|
||||||
|
|
||||||
|
|
||||||
|
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||||
|
lr_epochs = []
|
||||||
|
for i in range(1, max_epoch):
|
||||||
|
if i % epoch_size == 0:
|
||||||
|
lr_epochs.append(i)
|
||||||
|
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||||
|
"""Cosine annealing learning rate."""
|
||||||
|
base_lr = lr
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(max_epoch * steps_per_epoch)
|
||||||
|
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||||
|
|
||||||
|
lr_each_step = []
|
||||||
|
for i in range(total_steps):
|
||||||
|
last_epoch = i // steps_per_epoch
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||||
|
lr_each_step.append(lr)
|
||||||
|
|
||||||
|
return np.array(lr_each_step).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_cosine_annealing_lr_V2(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||||
|
"""Cosine annealing learning rate V2."""
|
||||||
|
base_lr = lr
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(max_epoch * steps_per_epoch)
|
||||||
|
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||||
|
|
||||||
|
last_lr = 0
|
||||||
|
last_epoch_V1 = 0
|
||||||
|
|
||||||
|
T_max_V2 = int(max_epoch*1/3)
|
||||||
|
|
||||||
|
lr_each_step = []
|
||||||
|
for i in range(total_steps):
|
||||||
|
last_epoch = i // steps_per_epoch
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
if i < total_steps*2/3:
|
||||||
|
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||||
|
last_lr = lr
|
||||||
|
last_epoch_V1 = last_epoch
|
||||||
|
else:
|
||||||
|
base_lr = last_lr
|
||||||
|
last_epoch = last_epoch-last_epoch_V1
|
||||||
|
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max_V2)) / 2
|
||||||
|
|
||||||
|
lr_each_step.append(lr)
|
||||||
|
return np.array(lr_each_step).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||||
|
"""Warmup cosine annealing learning rate."""
|
||||||
|
start_sample_epoch = 60
|
||||||
|
step_sample = 2
|
||||||
|
tobe_sampled_epoch = 60
|
||||||
|
end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch
|
||||||
|
max_sampled_epoch = max_epoch+tobe_sampled_epoch
|
||||||
|
T_max = max_sampled_epoch
|
||||||
|
|
||||||
|
base_lr = lr
|
||||||
|
warmup_init_lr = 0
|
||||||
|
total_steps = int(max_epoch * steps_per_epoch)
|
||||||
|
total_sampled_steps = int(max_sampled_epoch * steps_per_epoch)
|
||||||
|
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||||
|
|
||||||
|
lr_each_step = []
|
||||||
|
|
||||||
|
for i in range(total_sampled_steps):
|
||||||
|
last_epoch = i // steps_per_epoch
|
||||||
|
if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample):
|
||||||
|
continue
|
||||||
|
if i < warmup_steps:
|
||||||
|
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||||
|
else:
|
||||||
|
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / T_max)) / 2
|
||||||
|
lr_each_step.append(lr)
|
||||||
|
|
||||||
|
assert total_steps == len(lr_each_step)
|
||||||
|
return np.array(lr_each_step).astype(np.float32)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,177 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Util class or function."""
|
||||||
|
from mindspore.train.serialization import load_checkpoint
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class AverageMeter:
|
||||||
|
"""Computes and stores the average and current value"""
|
||||||
|
|
||||||
|
def __init__(self, name, fmt=':f', tb_writer=None):
|
||||||
|
self.name = name
|
||||||
|
self.fmt = fmt
|
||||||
|
self.reset()
|
||||||
|
self.tb_writer = tb_writer
|
||||||
|
self.cur_step = 1
|
||||||
|
self.val = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.val = 0
|
||||||
|
self.avg = 0
|
||||||
|
self.sum = 0
|
||||||
|
self.count = 0
|
||||||
|
|
||||||
|
def update(self, val, n=1):
|
||||||
|
self.val = val
|
||||||
|
self.sum += val * n
|
||||||
|
self.count += n
|
||||||
|
self.avg = self.sum / self.count
|
||||||
|
if self.tb_writer is not None:
|
||||||
|
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
|
||||||
|
self.cur_step += 1
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
fmtstr = '{name}:{avg' + self.fmt + '}'
|
||||||
|
return fmtstr.format(**self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_backbone(net, ckpt_path, args):
|
||||||
|
"""Load darknet53 backbone checkpoint."""
|
||||||
|
param_dict = load_checkpoint(ckpt_path)
|
||||||
|
yolo_backbone_prefix = 'feature_map.backbone'
|
||||||
|
darknet_backbone_prefix = 'network.backbone'
|
||||||
|
find_param = []
|
||||||
|
not_found_param = []
|
||||||
|
|
||||||
|
for name, cell in net.cells_and_names():
|
||||||
|
if name.startswith(yolo_backbone_prefix):
|
||||||
|
name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)
|
||||||
|
if isinstance(cell, (nn.Conv2d, nn.Dense)):
|
||||||
|
darknet_weight = '{}.weight'.format(name)
|
||||||
|
darknet_bias = '{}.bias'.format(name)
|
||||||
|
if darknet_weight in param_dict:
|
||||||
|
cell.weight.default_input = param_dict[darknet_weight].data
|
||||||
|
find_param.append(darknet_weight)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_weight)
|
||||||
|
if darknet_bias in param_dict:
|
||||||
|
cell.bias.default_input = param_dict[darknet_bias].data
|
||||||
|
find_param.append(darknet_bias)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_bias)
|
||||||
|
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||||
|
darknet_moving_mean = '{}.moving_mean'.format(name)
|
||||||
|
darknet_moving_variance = '{}.moving_variance'.format(name)
|
||||||
|
darknet_gamma = '{}.gamma'.format(name)
|
||||||
|
darknet_beta = '{}.beta'.format(name)
|
||||||
|
if darknet_moving_mean in param_dict:
|
||||||
|
cell.moving_mean.default_input = param_dict[darknet_moving_mean].data
|
||||||
|
find_param.append(darknet_moving_mean)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_moving_mean)
|
||||||
|
if darknet_moving_variance in param_dict:
|
||||||
|
cell.moving_variance.default_input = param_dict[darknet_moving_variance].data
|
||||||
|
find_param.append(darknet_moving_variance)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_moving_variance)
|
||||||
|
if darknet_gamma in param_dict:
|
||||||
|
cell.gamma.default_input = param_dict[darknet_gamma].data
|
||||||
|
find_param.append(darknet_gamma)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_gamma)
|
||||||
|
if darknet_beta in param_dict:
|
||||||
|
cell.beta.default_input = param_dict[darknet_beta].data
|
||||||
|
find_param.append(darknet_beta)
|
||||||
|
else:
|
||||||
|
not_found_param.append(darknet_beta)
|
||||||
|
|
||||||
|
args.logger.info('================found_param {}========='.format(len(find_param)))
|
||||||
|
args.logger.info(find_param)
|
||||||
|
args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
|
||||||
|
args.logger.info(not_found_param)
|
||||||
|
args.logger.info('=====load {} successfully ====='.format(ckpt_path))
|
||||||
|
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def default_wd_filter(x):
|
||||||
|
"""default weight decay filter."""
|
||||||
|
parameter_name = x.name
|
||||||
|
if parameter_name.endswith('.bias'):
|
||||||
|
# all bias not using weight decay
|
||||||
|
return False
|
||||||
|
if parameter_name.endswith('.gamma'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
return False
|
||||||
|
if parameter_name.endswith('.beta'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_param_groups(network):
|
||||||
|
"""Param groups for optimizer."""
|
||||||
|
decay_params = []
|
||||||
|
no_decay_params = []
|
||||||
|
for x in network.trainable_params():
|
||||||
|
parameter_name = x.name
|
||||||
|
if parameter_name.endswith('.bias'):
|
||||||
|
# all bias not using weight decay
|
||||||
|
no_decay_params.append(x)
|
||||||
|
elif parameter_name.endswith('.gamma'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
no_decay_params.append(x)
|
||||||
|
elif parameter_name.endswith('.beta'):
|
||||||
|
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||||
|
no_decay_params.append(x)
|
||||||
|
else:
|
||||||
|
decay_params.append(x)
|
||||||
|
|
||||||
|
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeRecord:
|
||||||
|
"""Log image shape."""
|
||||||
|
def __init__(self):
|
||||||
|
self.shape_record = {
|
||||||
|
320: 0,
|
||||||
|
352: 0,
|
||||||
|
384: 0,
|
||||||
|
416: 0,
|
||||||
|
448: 0,
|
||||||
|
480: 0,
|
||||||
|
512: 0,
|
||||||
|
544: 0,
|
||||||
|
576: 0,
|
||||||
|
608: 0,
|
||||||
|
'total': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def set(self, shape):
|
||||||
|
if len(shape) > 1:
|
||||||
|
shape = shape[0]
|
||||||
|
shape = int(shape)
|
||||||
|
self.shape_record[shape] += 1
|
||||||
|
self.shape_record['total'] += 1
|
||||||
|
|
||||||
|
def show(self, logger):
|
||||||
|
for key in self.shape_record:
|
||||||
|
rate = self.shape_record[key] / float(self.shape_record['total'])
|
||||||
|
logger.info('shape {}: {:.2f}%'.format(key, rate*100))
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,184 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""YOLOV3 dataset."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from pycocotools.coco import COCO
|
||||||
|
import mindspore.dataset as de
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as CV
|
||||||
|
|
||||||
|
from src.distributed_sampler import DistributedSampler
|
||||||
|
from src.transforms import reshape_fn, MultiScaleTrans
|
||||||
|
|
||||||
|
|
||||||
|
min_keypoints_per_image = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _has_only_empty_bbox(anno):
|
||||||
|
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
|
||||||
|
|
||||||
|
|
||||||
|
def _count_visible_keypoints(anno):
|
||||||
|
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
|
||||||
|
|
||||||
|
|
||||||
|
def has_valid_annotation(anno):
|
||||||
|
"""Check annotation file."""
|
||||||
|
# if it's empty, there is no annotation
|
||||||
|
if not anno:
|
||||||
|
return False
|
||||||
|
# if all boxes have close to zero area, there is no annotation
|
||||||
|
if _has_only_empty_bbox(anno):
|
||||||
|
return False
|
||||||
|
# keypoints task have a slight different critera for considering
|
||||||
|
# if an annotation is valid
|
||||||
|
if "keypoints" not in anno[0]:
|
||||||
|
return True
|
||||||
|
# for keypoint detection tasks, only consider valid images those
|
||||||
|
# containing at least min_keypoints_per_image
|
||||||
|
if _count_visible_keypoints(anno) >= min_keypoints_per_image:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class COCOYoloDataset:
|
||||||
|
"""YOLOV3 Dataset for COCO."""
|
||||||
|
def __init__(self, root, ann_file, remove_images_without_annotations=True,
|
||||||
|
filter_crowd_anno=True, is_training=True):
|
||||||
|
self.coco = COCO(ann_file)
|
||||||
|
self.root = root
|
||||||
|
self.img_ids = list(sorted(self.coco.imgs.keys()))
|
||||||
|
self.filter_crowd_anno = filter_crowd_anno
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
# filter images without any annotations
|
||||||
|
if remove_images_without_annotations:
|
||||||
|
img_ids = []
|
||||||
|
for img_id in self.img_ids:
|
||||||
|
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
|
||||||
|
anno = self.coco.loadAnns(ann_ids)
|
||||||
|
if has_valid_annotation(anno):
|
||||||
|
img_ids.append(img_id)
|
||||||
|
self.img_ids = img_ids
|
||||||
|
|
||||||
|
self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}
|
||||||
|
|
||||||
|
self.cat_ids_to_continuous_ids = {
|
||||||
|
v: i for i, v in enumerate(self.coco.getCatIds())
|
||||||
|
}
|
||||||
|
self.continuous_ids_cat_ids = {
|
||||||
|
v: k for k, v in self.cat_ids_to_continuous_ids.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index (int): Index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(img, target) (tuple): target is a dictionary contains "bbox", "segmentation" or "keypoints",
|
||||||
|
generated by the image's annotation. img is a PIL image.
|
||||||
|
"""
|
||||||
|
coco = self.coco
|
||||||
|
img_id = self.img_ids[index]
|
||||||
|
img_path = coco.loadImgs(img_id)[0]["file_name"]
|
||||||
|
img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
|
||||||
|
if not self.is_training:
|
||||||
|
return img, img_id
|
||||||
|
|
||||||
|
ann_ids = coco.getAnnIds(imgIds=img_id)
|
||||||
|
target = coco.loadAnns(ann_ids)
|
||||||
|
# filter crowd annotations
|
||||||
|
if self.filter_crowd_anno:
|
||||||
|
annos = [anno for anno in target if anno["iscrowd"] == 0]
|
||||||
|
else:
|
||||||
|
annos = [anno for anno in target]
|
||||||
|
|
||||||
|
target = {}
|
||||||
|
boxes = [anno["bbox"] for anno in annos]
|
||||||
|
target["bboxes"] = boxes
|
||||||
|
|
||||||
|
classes = [anno["category_id"] for anno in annos]
|
||||||
|
classes = [self.cat_ids_to_continuous_ids[cl] for cl in classes]
|
||||||
|
target["labels"] = classes
|
||||||
|
|
||||||
|
bboxes = target['bboxes']
|
||||||
|
labels = target['labels']
|
||||||
|
out_target = []
|
||||||
|
for bbox, label in zip(bboxes, labels):
|
||||||
|
tmp = []
|
||||||
|
# convert to [x_min y_min x_max y_max]
|
||||||
|
bbox = self._convetTopDown(bbox)
|
||||||
|
tmp.extend(bbox)
|
||||||
|
tmp.append(int(label))
|
||||||
|
# tmp [x_min y_min x_max y_max, label]
|
||||||
|
out_target.append(tmp)
|
||||||
|
return img, out_target
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.img_ids)
|
||||||
|
|
||||||
|
def _convetTopDown(self, bbox):
|
||||||
|
x_min = bbox[0]
|
||||||
|
y_min = bbox[1]
|
||||||
|
w = bbox[2]
|
||||||
|
h = bbox[3]
|
||||||
|
return [x_min, y_min, x_min+w, y_min+h]
|
||||||
|
|
||||||
|
|
||||||
|
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
|
||||||
|
config=None, is_training=True, shuffle=True):
|
||||||
|
"""Create dataset for YOLOV3."""
|
||||||
|
if is_training:
|
||||||
|
filter_crowd = True
|
||||||
|
remove_empty_anno = True
|
||||||
|
else:
|
||||||
|
filter_crowd = False
|
||||||
|
remove_empty_anno = False
|
||||||
|
|
||||||
|
yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
|
||||||
|
remove_images_without_annotations=remove_empty_anno, is_training=is_training)
|
||||||
|
distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
|
||||||
|
hwc_to_chw = CV.HWC2CHW()
|
||||||
|
|
||||||
|
config.dataset_size = len(yolo_dataset)
|
||||||
|
num_parallel_workers1 = int(64 / device_num)
|
||||||
|
num_parallel_workers2 = int(16 / device_num)
|
||||||
|
if is_training:
|
||||||
|
multi_scale_trans = MultiScaleTrans(config, device_num)
|
||||||
|
if device_num != 8:
|
||||||
|
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"],
|
||||||
|
num_parallel_workers=num_parallel_workers1,
|
||||||
|
sampler=distributed_sampler)
|
||||||
|
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
||||||
|
num_parallel_workers=num_parallel_workers2, drop_remainder=True)
|
||||||
|
else:
|
||||||
|
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler)
|
||||||
|
ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'],
|
||||||
|
num_parallel_workers=8, drop_remainder=True)
|
||||||
|
else:
|
||||||
|
ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
|
||||||
|
sampler=distributed_sampler)
|
||||||
|
compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
|
||||||
|
ds = ds.map(input_columns=["image", "img_id"],
|
||||||
|
output_columns=["image", "image_shape", "img_id"],
|
||||||
|
columns_order=["image", "image_shape", "img_id"],
|
||||||
|
operations=compose_map_func, num_parallel_workers=8)
|
||||||
|
ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
|
ds = ds.repeat(max_epoch)
|
||||||
|
|
||||||
|
return ds, len(yolo_dataset)
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue