!8357 New add YOLOV4 network into ModelZoo.
From: @linqingke Reviewed-by: @yingjy,@oacjiewen Signed-off-by: @yingjypull/8357/MERGE
commit
c5d9c78e46
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,65 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Convert ckpt to air."""
|
||||
import os
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
|
||||
|
||||
from src.yolo import YOLOV4CspDarkNet53
|
||||
|
||||
# Configure MindSpore at import time: graph mode, Ascend device target, no graph dumps.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
def save_air():
    """Export a YOLOv4 (CSPDarkNet53) checkpoint to an AIR model file.

    Command-line arguments:
        --pretrained: path of the checkpoint to load (default ``''``).
        --batch_size: batch dimension of the dummy input used for export
            (default 8).

    The inference network is built with ``is_training=False``. When
    ``--pretrained`` names an existing file its parameters are loaded into
    the network; otherwise a warning is printed and the network is exported
    with the parameters it was constructed with. The result is written to
    ``yolov4.air`` relative to the current working directory.
    """
    print('============= YOLOV4 start save air ==================')

    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    args = parser.parse_args()

    network = YOLOV4CspDarkNet53(is_training=False)
    # Second network input: the image size as a float tensor.
    input_shape = Tensor(tuple([416, 416]), mindspore.float32)

    if os.path.isfile(args.pretrained):
        wrapper_prefix = 'yolo_network.'
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                # Skipped on purpose; presumably optimizer state that the
                # inference network has no slot for -- TODO confirm.
                continue
            if key.startswith(wrapper_prefix):
                # Drop the wrapper prefix so the key matches the bare network.
                param_dict_new[key[len(wrapper_prefix):]] = values
            else:
                param_dict_new[key] = values

        load_param_into_net(network, param_dict_new)
        print('load model {} success'.format(args.pretrained))
    else:
        # Previously this case was silent, which made it easy to ship a model
        # without trained weights by mistyping the checkpoint path.
        print('warning: checkpoint "{}" is not a file, exporting the network '
              'without loading any parameters'.format(args.pretrained))

    input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 416, 416)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    export(network, tensor_input_data, input_shape, file_name='yolov4.air', file_format='AIR')

    print("export model success.")
|
||||
|
||||
|
||||
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    save_air()
|
@ -0,0 +1,22 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from src.yolo import YOLOV4CspDarkNet53
|
||||
|
||||
def create_network(name, *args, **kwargs):
    """Return the network registered under ``name``.

    The only recognised name is ``"yolov4_cspdarknet53"``, which yields a
    ``YOLOV4CspDarkNet53`` built with ``is_training=False``. Extra positional
    and keyword arguments are accepted but not used.

    Raises:
        NotImplementedError: if ``name`` is anything else.
    """
    if name != "yolov4_cspdarknet53":
        raise NotImplementedError(f"{name} is not implemented in the repo")
    return YOLOV4CspDarkNet53(is_training=False)
|
@ -0,0 +1,82 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Require exactly three positional arguments; otherwise print usage and abort.
# The usage line says "bash" because this script uses bash-only syntax
# (substring expansion and a C-style for loop) that plain POSIX sh rejects.
if [ $# -ne 3 ]
then
    echo "Usage: bash run_distribute_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE] [RANK_TABLE_FILE]"
    exit 1
fi
|
||||
|
||||
# Print an absolute form of the path passed as $1.
# A path that already starts with "/" is echoed unchanged; anything else is
# joined to $PWD and normalised with `realpath -m`.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        # Quoted so that a path containing spaces stays a single argument.
        realpath -m "$PWD/$1"
    fi
}
|
||||
|
||||
# Resolve the three positional arguments to absolute paths. Every expansion
# below is quoted so that paths containing spaces are not word-split (an
# unquoted path inside `[ ... ]` makes the test itself fail).
DATASET_PATH=$(get_real_path "$1")
PRETRAINED_BACKBONE=$(get_real_path "$2")
RANK_TABLE_FILE=$(get_real_path "$3")
echo "$DATASET_PATH"
echo "$PRETRAINED_BACKBONE"
echo "$RANK_TABLE_FILE"

if [ ! -d "$DATASET_PATH" ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f "$PRETRAINED_BACKBONE" ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi

if [ ! -f "$RANK_TABLE_FILE" ]
then
    echo "error: RANK_TABLE_FILE=$RANK_TABLE_FILE is not a file"
    exit 1
fi

export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE

# Start one background training process per device, each in its own
# working directory holding a fresh copy of the sources.
for((i=0; i<${DEVICE_NUM}; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit 1
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
    python train.py \
        --data_dir="$DATASET_PATH" \
        --pretrained_backbone="$PRETRAINED_BACKBONE" \
        --is_distributed=1 \
        --lr=0.012 \
        --t_max=320 \
        --max_epoch=320 \
        --warmup_epochs=20 \
        --per_batch_size=8 \
        --lr_scheduler=cosine_annealing > log.txt 2>&1 &
    cd ..
done
|
@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Require exactly two positional arguments; otherwise print usage and abort.
# The usage line says "bash" because this script uses bash-only syntax
# (substring expansion in get_real_path) that plain POSIX sh rejects.
if [ $# -ne 2 ]
then
    echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi
|
||||
|
||||
# Print an absolute form of the path passed as $1.
# A path that already starts with "/" is echoed unchanged; anything else is
# joined to $PWD and normalised with `realpath -m`.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        # Quoted so that a path containing spaces stays a single argument.
        realpath -m "$PWD/$1"
    fi
}
|
||||
# Resolve both positional arguments to absolute paths. Expansions are quoted
# so that paths containing spaces are not word-split.
DATASET_PATH=$(get_real_path "$1")
CHECKPOINT_PATH=$(get_real_path "$2")
echo "$DATASET_PATH"
echo "$CHECKPOINT_PATH"

if [ ! -d "$DATASET_PATH" ]
then
    # Report the variable that was actually tested ($PATH1 was never set).
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f "$CHECKPOINT_PATH" ]
then
    # Report the variable that was actually tested ($PATH2 was never set).
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
    exit 1
fi

export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Recreate ./eval with a fresh copy of the sources and run eval.py there
# in the background, logging to ./eval/log.txt.
if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit 1
env > env.log
echo "start infering for device $DEVICE_ID"
python eval.py \
    --data_dir="$DATASET_PATH" \
    --pretrained="$CHECKPOINT_PATH" \
    --testing_shape=416 > log.txt 2>&1 &
cd ..
|
@ -0,0 +1,74 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Require exactly two positional arguments; otherwise print usage and abort.
# The usage line says "bash" because this script uses bash-only syntax
# (substring expansion in get_real_path) that plain POSIX sh rejects.
if [ $# -ne 2 ]
then
    echo "Usage: bash run_standalone_train.sh [DATASET_PATH] [PRETRAINED_BACKBONE]"
    exit 1
fi
|
||||
|
||||
# Print an absolute form of the path passed as $1.
# A path that already starts with "/" is echoed unchanged; anything else is
# joined to $PWD and normalised with `realpath -m`.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        # Quoted so that a path containing spaces stays a single argument.
        realpath -m "$PWD/$1"
    fi
}
|
||||
|
||||
# Resolve both positional arguments to absolute paths. Expansions are quoted
# so that paths containing spaces are not word-split (an unquoted path inside
# `[ ... ]` makes the test itself fail).
DATASET_PATH=$(get_real_path "$1")
echo "$DATASET_PATH"
PRETRAINED_BACKBONE=$(get_real_path "$2")
echo "$PRETRAINED_BACKBONE"

if [ ! -d "$DATASET_PATH" ]
then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f "$PRETRAINED_BACKBONE" ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi

export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Recreate ./train with a fresh copy of the sources and run train.py there
# in the background, logging to ./train/log.txt.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cd ./train || exit 1
echo "start training for device $DEVICE_ID"
env > env.log

python train.py \
    --data_dir="$DATASET_PATH" \
    --pretrained_backbone="$PRETRAINED_BACKBONE" \
    --is_distributed=0 \
    --lr=0.012 \
    --t_max=320 \
    --max_epoch=320 \
    --warmup_epochs=4 \
    --training_shape=416 \
    --lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
|
@ -0,0 +1,66 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
# Require exactly two positional arguments; otherwise print usage and abort.
# The usage line says "bash" because this script uses bash-only syntax
# (substring expansion in get_real_path) that plain POSIX sh rejects.
if [ $# -ne 2 ]
then
    echo "Usage: bash run_test.sh [DATASET_PATH] [CHECKPOINT_PATH]"
    exit 1
fi
|
||||
|
||||
# Print an absolute form of the path passed as $1.
# A path that already starts with "/" is echoed unchanged; anything else is
# joined to $PWD and normalised with `realpath -m`.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        # Quoted so that a path containing spaces stays a single argument.
        realpath -m "$PWD/$1"
    fi
}
|
||||
# Resolve both positional arguments to absolute paths. Expansions are quoted
# so that paths containing spaces are not word-split.
DATASET_PATH=$(get_real_path "$1")
CHECKPOINT_PATH=$(get_real_path "$2")
echo "$DATASET_PATH"
echo "$CHECKPOINT_PATH"

if [ ! -d "$DATASET_PATH" ]
then
    # Report the variable that was actually tested ($PATH1 was never set).
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

if [ ! -f "$CHECKPOINT_PATH" ]
then
    # Report the variable that was actually tested ($PATH2 was never set).
    echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a file"
    exit 1
fi

export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0

# Recreate ./test with a fresh copy of the sources and run test.py there
# in the background, logging to ./test/log.txt.
if [ -d "test" ];
then
    rm -rf ./test
fi
mkdir ./test
cp ../*.py ./test
cp -r ../src ./test
cd ./test || exit 1
env > env.log
echo "start infering for device $DEVICE_ID"
python test.py \
    --data_dir="$DATASET_PATH" \
    --pretrained="$CHECKPOINT_PATH" \
    --testing_shape=416 > log.txt 2>&1 &
cd ..
|
@ -0,0 +1,14 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
@ -0,0 +1,69 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Config parameters for Darknet based yolov4_cspdarknet53 models."""
|
||||
|
||||
|
||||
class ConfigYOLOV4CspDarkNet53:
    """
    Config parameters for the yolov4_cspdarknet53.

    Every setting is a plain class attribute, so values can be read from the
    class itself or from an instance.

    Examples:
        ConfigYOLOV4CspDarkNet53()
    """
    # ---- train_param ----
    # data augmentation related
    hue = 0.1
    saturation = 1.5
    value = 1.5
    jitter = 0.3

    resize_rate = 10
    # Square [side, side] shapes from 416 to 736 inclusive, in steps of 32.
    multi_scale = [[side, side] for side in range(416, 737, 32)]

    num_classes = 80
    max_box = 90

    backbone_input_shape = [32, 64, 128, 256, 512]
    backbone_shape = [64, 128, 256, 512, 1024]
    backbone_layers = [1, 2, 8, 8, 4]

    # confidence under ignore_threshold means no object when training
    ignore_threshold = 0.7

    # h->w (axis-order note kept from the original -- TODO confirm which
    # element of each pair is the height)
    anchor_scales = [
        (12, 16), (19, 36), (40, 28),
        (36, 75), (76, 55), (72, 146),
        (142, 110), (192, 243), (459, 401),
    ]
    # 3 * (num_classes + 5); evaluates to 255 with num_classes = 80.
    out_channel = 3 * (num_classes + 5)

    # ---- test_param ----
    test_img_shape = [608, 608]
|
@ -0,0 +1,220 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""DarkNet model."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Mish(nn.Cell):
    """Mish activation cell: ``x * tanh(softplus(x))``."""
    def __init__(self):
        super(Mish, self).__init__()
        self.softplus = P.Softplus()
        self.tanh = P.Tanh()
        self.mul = P.Mul()

    def construct(self, input_x):
        """Return ``input_x`` multiplied by ``tanh(softplus(input_x))``."""
        gate = self.tanh(self.softplus(input_x))
        return self.mul(input_x, gate)
|
||||
|
||||
def conv_block(in_channels, out_channels, kernel_size, stride, dilation=1):
    """Build a Conv2d -> BatchNorm2d -> Mish stack.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution and the
            feature count of the batch norm.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: convolution dilation. Default: 1.

    Returns:
        nn.SequentialCell holding the three layers in that order. The
        convolution uses ``pad_mode='same'`` with ``padding=0``; the batch
        norm uses ``momentum=0.9`` and ``eps=1e-5``.
    """
    conv = nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=0,
                     dilation=dilation,
                     pad_mode='same')
    norm = nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-5)
    return nn.SequentialCell([conv, norm, Mish()])
|
||||
|
||||
|
||||
class ResidualBlock(nn.Cell):
    """
    Residual block: a 1x1 conv_block followed by a 3x3 conv_block (both
    stride 1), whose result is added to the block input.

    Args:
        in_channels: Integer. Input channel.
        out_channels: Integer. Output channel of both conv blocks.

    Returns:
        Tensor, output tensor.

    Examples:
        ResidualBlock(64, 64)
    """
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv_block(in_channels, out_channels, kernel_size=1, stride=1)
        self.conv2 = conv_block(out_channels, out_channels, kernel_size=3, stride=1)
        self.add = P.TensorAdd()

    def construct(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        # NOTE(review): the skip addition presumably needs in_channels ==
        # out_channels so both operands have the same shape -- confirm.
        branch = self.conv2(self.conv1(x))
        return self.add(branch, x)
|
||||
|
||||
class CspDarkNet53(nn.Cell):
    """
    CSPDarkNet53 backbone built from conv_block layers and stacks of ``block``.

    Args:
        block: Cell class called as ``block(in_channel, out_channel)`` to
            build each element of the residual stacks (layer2..layer5).
        detect: Bool. When True, ``construct`` returns three feature maps;
            otherwise only the last one. Default: False.

    Returns:
        If ``detect`` is True, a tuple ``(c24, c31, c38)`` holding the outputs
        of conv17, conv22 and conv27; otherwise the single tensor ``c38``.

    Examples:
        CspDarkNet53(ResidualBlock, detect=True)
    """
    def __init__(self,
                 block,
                 detect=False):
        super(CspDarkNet53, self).__init__()

        # Channel count of the final conv block (conv27).
        self.outchannel = 1024
        self.detect = detect
        # Concatenation along axis 1; used to merge the two branches of a stage.
        self.concat = P.Concat(axis=1)
        self.add = P.TensorAdd()

        # NOTE(review): attribute names (conv0..conv27, layer2..layer5) are
        # presumably part of the checkpoint parameter keys -- confirm before
        # renaming or reordering any of them.
        self.conv0 = conv_block(3, 32, kernel_size=3, stride=1)
        self.conv1 = conv_block(32, 64, kernel_size=3, stride=2)
        self.conv2 = conv_block(64, 64, kernel_size=1, stride=1)
        self.conv3 = conv_block(64, 32, kernel_size=1, stride=1)
        self.conv4 = conv_block(32, 64, kernel_size=3, stride=1)
        self.conv5 = conv_block(64, 64, kernel_size=1, stride=1)
        self.conv6 = conv_block(64, 64, kernel_size=1, stride=1)
        self.conv7 = conv_block(128, 64, kernel_size=1, stride=1)
        self.conv8 = conv_block(64, 128, kernel_size=3, stride=2)
        self.conv9 = conv_block(128, 64, kernel_size=1, stride=1)
        self.conv10 = conv_block(64, 64, kernel_size=1, stride=1)
        self.conv11 = conv_block(128, 64, kernel_size=1, stride=1)
        self.conv12 = conv_block(128, 128, kernel_size=1, stride=1)
        self.conv13 = conv_block(128, 256, kernel_size=3, stride=2)
        self.conv14 = conv_block(256, 128, kernel_size=1, stride=1)
        self.conv15 = conv_block(128, 128, kernel_size=1, stride=1)
        self.conv16 = conv_block(256, 128, kernel_size=1, stride=1)
        self.conv17 = conv_block(256, 256, kernel_size=1, stride=1)
        self.conv18 = conv_block(256, 512, kernel_size=3, stride=2)
        self.conv19 = conv_block(512, 256, kernel_size=1, stride=1)
        self.conv20 = conv_block(256, 256, kernel_size=1, stride=1)
        self.conv21 = conv_block(512, 256, kernel_size=1, stride=1)
        self.conv22 = conv_block(512, 512, kernel_size=1, stride=1)
        self.conv23 = conv_block(512, 1024, kernel_size=3, stride=2)
        self.conv24 = conv_block(1024, 512, kernel_size=1, stride=1)
        self.conv25 = conv_block(512, 512, kernel_size=1, stride=1)
        self.conv26 = conv_block(1024, 512, kernel_size=1, stride=1)
        self.conv27 = conv_block(1024, 1024, kernel_size=1, stride=1)

        # Stacks of 2, 8, 8 and 4 blocks respectively.
        self.layer2 = self._make_layer(block, 2, in_channel=64, out_channel=64)
        self.layer3 = self._make_layer(block, 8, in_channel=128, out_channel=128)
        self.layer4 = self._make_layer(block, 8, in_channel=256, out_channel=256)
        self.layer5 = self._make_layer(block, 4, in_channel=512, out_channel=512)

    def _make_layer(self, block, layer_num, in_channel, out_channel):
        """
        Make Layer for DarkNet.

        The first block maps ``in_channel`` to ``out_channel``; the remaining
        ``layer_num - 1`` blocks map ``out_channel`` to ``out_channel``.

        :param block: Cell. DarkNet block.
        :param layer_num: Integer. Layer number.
        :param in_channel: Integer. Input channel.
        :param out_channel: Integer. Output channel.
        :return: SequentialCell, the output layer.

        Examples:
            _make_layer(ConvBlock, 1, 128, 256)
        """
        layers = []
        darkblk = block(in_channel, out_channel)
        layers.append(darkblk)

        for _ in range(1, layer_num):
            darkblk = block(out_channel, out_channel)
            layers.append(darkblk)

        return nn.SequentialCell(layers)

    def construct(self, x):
        """Run the backbone on ``x``; see the class docstring for the return value."""
        c1 = self.conv0(x)
        c2 = self.conv1(c1)  # route
        # First stage: the residual step is written out inline (conv3, conv4, add).
        c3 = self.conv2(c2)
        c4 = self.conv3(c3)
        c5 = self.conv4(c4)
        c6 = self.add(c3, c5)
        c7 = self.conv5(c6)
        c8 = self.conv6(c2)
        c9 = self.concat((c7, c8))
        c10 = self.conv7(c9)
        c11 = self.conv8(c10)  # route
        # Later stages share one shape: a stride-2 conv, then two branches taken
        # from its output (one through a residual stack), concatenated and fused.
        c12 = self.conv9(c11)
        c13 = self.layer2(c12)
        c14 = self.conv10(c13)
        c15 = self.conv11(c11)
        c16 = self.concat((c14, c15))
        c17 = self.conv12(c16)
        c18 = self.conv13(c17)  # route
        c19 = self.conv14(c18)
        c20 = self.layer3(c19)
        c21 = self.conv15(c20)
        c22 = self.conv16(c18)
        c23 = self.concat((c21, c22))
        c24 = self.conv17(c23)  # output1
        c25 = self.conv18(c24)  # route
        c26 = self.conv19(c25)
        c27 = self.layer4(c26)
        c28 = self.conv20(c27)
        c29 = self.conv21(c25)
        c30 = self.concat((c28, c29))
        c31 = self.conv22(c30)  # output2
        c32 = self.conv23(c31)  # route
        c33 = self.conv24(c32)
        c34 = self.layer5(c33)
        c35 = self.conv25(c34)
        c36 = self.conv26(c32)
        c37 = self.concat((c35, c36))
        c38 = self.conv27(c37)  # output3

        if self.detect:
            return c24, c31, c38

        return c38

    def get_out_channels(self):
        """Return ``self.outchannel`` (set to 1024 in ``__init__``)."""
        return self.outchannel
|
@ -0,0 +1,60 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Yolo dataset distributed sampler."""
|
||||
from __future__ import division
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
    """Distributed sampler.

    Splits the index range ``[0, dataset_size)`` into ``num_replicas``
    interleaved shards of equal length and iterates over the shard that
    belongs to ``rank``. When ``dataset_size`` is not a multiple of
    ``num_replicas`` the index list is padded by repeating indices from its
    start, cycling as often as needed.

    Args:
        dataset_size: number of samples in the dataset.
        num_replicas: number of shards. Defaults to 1 (with a printed notice)
            when None.
        rank: shard to iterate over. Defaults to 0 (with a printed notice)
            when None.
        shuffle: when True, each iteration uses a permutation seeded with the
            current ``epoch`` and then increments ``epoch``.
    """
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Samples per shard, rounded up so every shard has the same length.
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # Deterministic shuffle: the seed is the epoch counter, so every
            # replica iterating the same epoch draws the same permutation.
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # Pad to total_size. A single slice of the list is not enough when the
        # padding is longer than the list itself (e.g. 1 sample, 4 replicas),
        # so repeat the whole list and cut it to length.
        if indices and len(indices) < self.total_size:
            repeats = -(-self.total_size // len(indices))  # ceiling division
            indices = (indices * repeats)[:self.total_size]
        assert len(indices) == self.total_size

        # Subsample: every num_replicas-th index, starting at this rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices produced per iteration."""
        return self.num_samples
|
@ -0,0 +1,204 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parameter init."""
|
||||
import math
|
||||
from functools import reduce
|
||||
import numpy as np
|
||||
from mindspore.common import initializer as init
|
||||
from mindspore.common.initializer import Initializer as MeInitializer
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore.nn as nn
|
||||
from .util import load_backbone
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: name of the non-linear function, e.g. ``'relu'``.
        param: negative slope, only read for ``'leaky_relu'``; ``None``
            selects 0.01.

    Raises:
        ValueError: if ``nonlinearity`` is not one of the supported names, or
            if ``param`` is neither ``None`` nor a non-bool int/float.

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    unit_gain_fns = ('linear', 'conv1d', 'conv2d', 'conv3d',
                     'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                     'sigmoid')
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    elif isinstance(param, (int, float)) and not isinstance(param, bool):
        # bool is a subclass of int, so it has to be excluded explicitly.
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))
|
||||
|
||||
|
||||
def _assignment(arr, num):
    """Write ``num`` into ``arr`` in place and return the array.

    A 0-d ``arr`` is assigned through a one-element reshape of itself and
    handed back reshaped to ``()``. For any other shape the slice assignment
    ``arr[:] = ...`` is used, taking ``num[:]`` when ``num`` is an ndarray.
    """
    if arr.shape != ():
        arr[:] = num[:] if isinstance(num, np.ndarray) else num
        return arr

    # A 0-d array cannot be slice-assigned, so go through a 1-element reshape.
    as_vector = arr.reshape((1))
    as_vector[:] = num
    return as_vector.reshape(())
|
||||
|
||||
|
||||
def _calculate_correct_fan(array, mode):
    """Return the fan-in or fan-out of ``array`` as selected by ``mode``.

    Args:
        array: object passed on to ``_calculate_fan_in_and_fan_out``.
        mode: ``'fan_in'`` or ``'fan_out'``, matched case-insensitively.

    Raises:
        ValueError: if ``mode`` is neither of the two supported values.
    """
    valid_modes = ['fan_in', 'fan_out']
    mode = mode.lower()
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    # The helper returns (fan_in, fan_out), the same order as valid_modes.
    fans = _calculate_fan_in_and_fan_out(array)
    return fans[valid_modes.index(mode)]
|
||||
|
||||
|
||||
def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Draw Kaiming (He) uniform values with the shape of ``arr``.

    Follows `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015). Values
    are sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    ``arr`` itself is not modified; only its shape is read.

    Args:
        arr: array-like with a ``shape`` of at least two dimensions.
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``).
        mode: either ``'fan_in'`` (default) or ``'fan_out'``.
        nonlinearity: name of the non-linear function, passed to
            ``calculate_gain``.

    Returns:
        numpy.ndarray of shape ``arr.shape`` holding the sampled values.

    Examples:
        >>> w = np.empty((3, 5))
        >>> values = kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan_value = _calculate_correct_fan(arr, mode)
    scale = calculate_gain(nonlinearity, a) / math.sqrt(fan_value)
    # A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
    limit = math.sqrt(3.0) * scale
    return np.random.uniform(-limit, limit, arr.shape)
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(arr):
    """Calculate fan in and fan out.

    ``shape[1]`` and ``shape[0]`` are each multiplied by the product of all
    remaining dimensions.

    Args:
        arr: object with a ``shape`` of at least two dimensions.

    Returns:
        Tuple ``(fan_in, fan_out)``.

    Raises:
        ValueError: if ``arr`` has fewer than two dimensions.
    """
    shape = arr.shape
    if len(shape) < 2:
        raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions")

    # Product of every dimension after the first two; 1 when there are none.
    field_size = 1
    for extent in shape[2:]:
        field_size *= extent

    return shape[1] * field_size, shape[0] * field_size
|
||||
|
||||
|
||||
class KaimingUniform(MeInitializer):
    """Initializer backed by ``kaiming_uniform_``.

    Args:
        a: forwarded to ``kaiming_uniform_`` as its ``a`` argument.
        mode: forwarded as ``mode`` (default ``'fan_in'``).
        nonlinearity: forwarded as ``nonlinearity`` (default ``'leaky_relu'``).
    """

    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super().__init__()
        self.a, self.mode, self.nonlinearity = a, mode, nonlinearity

    def _initialize(self, arr):
        """Sample Kaiming-uniform values for ``arr`` and pass them to ``_assignment``."""
        samples = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity)
        _assignment(arr, samples)
|
||||
|
||||
|
||||
def default_recurisive_init(custom_cell):
    """Re-initialize every Conv2d and Dense cell found under ``custom_cell``.

    Weights are reset with ``KaimingUniform(a=math.sqrt(5))``.  Biases, when
    the cell has one, are reset with ``init.Uniform(1 / sqrt(fan_in))`` where
    ``fan_in`` is derived from the weight's shape.  Every other cell type
    (BatchNorm included) is left untouched.

    Args:
        custom_cell: cell whose ``cells_and_names()`` is walked.
    """
    for _, sub_cell in custom_cell.cells_and_names():
        # Conv2d and Dense share exactly the same initialization recipe.
        if not isinstance(sub_cell, (nn.Conv2d, nn.Dense)):
            continue

        kaiming_weight = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                          sub_cell.weight.shape,
                                          sub_cell.weight.dtype)
        sub_cell.weight.set_data(kaiming_weight)

        if sub_cell.bias is None:
            continue

        fan_in, _ = _calculate_fan_in_and_fan_out(sub_cell.weight)
        bound = 1 / math.sqrt(fan_in)
        uniform_bias = init.initializer(init.Uniform(bound),
                                        sub_cell.bias.shape,
                                        sub_cell.bias.dtype)
        sub_cell.bias.set_data(uniform_bias)
|
||||
|
||||
def load_yolov4_params(args, network):
    """Load backbone and/or full-network weights selected by ``args``.

    Args:
        args: object providing ``pretrained_backbone``, ``resume_yolov4`` and
            a ``logger`` with an ``info`` method.
        network: network the parameters are loaded into.

    When ``args.pretrained_backbone`` is truthy the backbone checkpoint is
    loaded through ``load_backbone``.  When ``args.resume_yolov4`` is truthy
    that checkpoint is read, entries whose key starts with ``'moments.'`` are
    dropped, a leading ``'yolo_network.'`` is stripped from the remaining
    keys, and the result is loaded with ``load_param_into_net``.
    """
    backbone_ckpt = args.pretrained_backbone
    if backbone_ckpt:
        network = load_backbone(network, backbone_ckpt, args)
        args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone))
    else:
        args.logger.info('Not load pre-trained backbone, please be careful')

    if not args.resume_yolov4:
        return

    wrapper_prefix = 'yolo_network.'
    restored = {}
    for key, values in load_checkpoint(args.resume_yolov4).items():
        if key.startswith('moments.'):
            continue
        # Strip the wrapper prefix so the key matches the bare network's names.
        target = key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key
        restored[target] = values
        args.logger.info('in resume {}'.format(key))

    args.logger.info('resume finished')
    load_param_into_net(network, restored)
    args.logger.info('load_model {} success'.format(args.resume_yolov4))
|
@ -0,0 +1,80 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
    """
    Rank-aware logger with optional console and file output.

    Args:
        logger_name: String. Logger name.
        rank: Integer. Rank id.
    """

    def __init__(self, logger_name, rank=0):
        super().__init__(logger_name)
        self.rank = rank
        # Only ranks that are a multiple of 8 get a console handler.
        if rank % 8 != 0:
            return
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setLevel(logging.INFO)
        stdout_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
        self.addHandler(stdout_handler)

    def setup_logging_file(self, log_dir, rank=0):
        """Create ``log_dir`` if missing and attach a timestamped per-rank log file."""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')
        self.log_fn = os.path.join(log_dir, timestamp + '_rank_{}.log'.format(rank))
        file_handler = logging.FileHandler(self.log_fn)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s:%(message)s'))
        self.addHandler(file_handler)

    def info(self, msg, *args, **kwargs):
        """Log ``msg`` at INFO level when that level is enabled."""
        if not self.isEnabledFor(logging.INFO):
            return
        self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        """Log every attribute of ``args`` (as returned by ``vars``), one per line."""
        self.info('Args:')
        for key, value in vars(args).items():
            self.info('--> %s: %s', key, value)
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        """Log ``msg`` framed by rows of asterisks; only emitted when ``self.rank`` is 0."""
        if not (self.isEnabledFor(logging.INFO) and self.rank == 0):
            return
        line_width = 2
        border = ('*'*70 + '\n')*line_width
        padding = ('*'*line_width + '\n')*2
        body = '*'*line_width + ' '*8 + msg + '\n'
        self.info('\n' + border + padding + body + padding + border, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
    """Build the ``'yolov4_cspdarknet53'`` LOGGER for ``rank`` with a log file under ``path``."""
    yolo_logger = LOGGER('yolov4_cspdarknet53', rank)
    yolo_logger.setup_logging_file(path, rank)
    return yolo_logger
|
@ -0,0 +1,70 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""YOLOV4 loss."""
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class XYLoss(nn.Cell):
    """Sigmoid cross-entropy loss for the predicted x and y."""

    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, box_loss_scale, predict_xy, true_xy):
        """Return ``reduce_sum(object_mask * box_loss_scale * CE(predict_xy, true_xy), ())``."""
        per_element = self.cross_entropy(predict_xy, true_xy)
        weighted = object_mask * box_loss_scale * per_element
        return self.reduce_sum(weighted, ())
|
||||
|
||||
|
||||
class WHLoss(nn.Cell):
    """Half squared-error loss for the predicted w and h."""

    def __init__(self):
        super(WHLoss, self).__init__()
        self.square = P.Square()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, box_loss_scale, predict_wh, true_wh):
        """Return ``reduce_sum(object_mask * box_loss_scale * 0.5 * (true_wh - predict_wh)**2, ())``.

        Args:
            object_mask: multiplicative mask applied to the per-element loss.
            box_loss_scale: multiplicative per-box weight.
            predict_wh: predicted width/height values.
            true_wh: target width/height values.
        """
        # Reuse the operator created in __init__ instead of instantiating a
        # new P.Square() on every call.
        wh_loss = object_mask * box_loss_scale * 0.5 * self.square(true_wh - predict_wh)
        wh_loss = self.reduce_sum(wh_loss, ())
        return wh_loss
|
||||
|
||||
|
||||
class ConfidenceLoss(nn.Cell):
    """Sigmoid cross-entropy loss for the objectness confidence."""

    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, predict_confidence, ignore_mask):
        """Sum the confidence loss.

        Where ``object_mask`` is set the cross-entropy term is kept as is;
        elsewhere it is additionally multiplied by ``ignore_mask``.
        """
        raw_loss = self.cross_entropy(predict_confidence, object_mask)
        object_term = object_mask * raw_loss
        background_term = (1 - object_mask) * raw_loss * ignore_mask
        return self.reduce_sum(object_term + background_term, ())
|
||||
|
||||
|
||||
class ClassLoss(nn.Cell):
    """Sigmoid cross-entropy loss for the class predictions."""

    def __init__(self):
        super().__init__()
        self.cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.reduce_sum = P.ReduceSum()

    def construct(self, object_mask, predict_class, class_probs):
        """Return ``reduce_sum(object_mask * CE(predict_class, class_probs), ())``."""
        masked = object_mask * self.cross_entropy(predict_class, class_probs)
        return self.reduce_sum(masked, ())
|
@ -0,0 +1,180 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate scheduler."""
|
||||
import math
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    """Return the linearly interpolated warmup learning rate at ``current_step``.

    The value starts at ``init_lr`` for step 0 and reaches ``base_lr`` at
    ``current_step == warmup_steps``.
    """
    start = float(init_lr)
    slope = (float(base_lr) - start) / float(warmup_steps)
    return start + slope * current_step
|
||||
|
||||
|
||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Build a per-step schedule: linear warmup followed by step decay.

    Args:
        lr: base learning rate reached at the end of warmup.
        lr_epochs: epochs at which the rate is multiplied by ``gamma``.
        steps_per_epoch: number of steps in one epoch.
        warmup_epochs: number of epochs spent warming up from 0.
        max_epoch: total number of epochs.
        gamma: multiplicative decay factor applied at each milestone.

    Returns:
        ``numpy`` float32 array with ``int(max_epoch * steps_per_epoch)`` entries.

    Milestones that fall inside the warmup window are not applied, because
    the warmup branch overrides the rate on those steps.
    """
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    # Number of decays scheduled at each step index (repeated milestones stack).
    decay_counts = Counter(epoch * steps_per_epoch for epoch in lr_epochs)

    schedule = []
    current = lr
    for step in range(total_steps):
        if step < warmup_steps:
            current = linear_warmup_lr(step + 1, warmup_steps, lr, 0)
        else:
            current = current * gamma**decay_counts[step]
        schedule.append(current)

    return np.array(schedule).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """Step-decay schedule without warmup; delegates to ``warmup_step_lr``."""
    no_warmup_epochs = 0
    return warmup_step_lr(lr, milestones, steps_per_epoch, no_warmup_epochs, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """Decay the rate by ``gamma`` every ``epoch_size`` epochs via ``multi_step_lr``."""
    # Every epoch in [1, max_epoch) that is a multiple of epoch_size is a milestone.
    decay_epochs = [epoch for epoch in range(1, max_epoch) if epoch % epoch_size == 0]
    return multi_step_lr(lr, decay_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
    """Build a per-step schedule: linear warmup, then epoch-wise cosine annealing.

    Args:
        lr: base learning rate.
        steps_per_epoch: number of steps in one epoch.
        warmup_epochs: number of epochs spent warming up from 0.
        max_epoch: total number of epochs.
        t_max: period (in epochs) used in the cosine term.
        eta_min: lower bound of the cosine curve.

    Returns:
        ``numpy`` float32 array with ``int(max_epoch * steps_per_epoch)`` entries.
    """
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            value = linear_warmup_lr(step + 1, warmup_steps, lr, 0)
        else:
            # The cosine term only changes once per epoch.
            epoch = step // steps_per_epoch
            value = eta_min + (lr - eta_min) * (1. + math.cos(math.pi*epoch / t_max)) / 2
        schedule.append(value)

    return np.array(schedule).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_v2(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
    """Build a two-phase cosine schedule with linear warmup.

    After warmup, steps before ``total_steps * 2 / 3`` follow a cosine curve
    with period ``t_max`` starting from ``lr``.  The remaining steps restart a
    cosine curve with period ``int(max_epoch * 1 / 3)`` whose starting value is
    the last rate produced by the first phase, counted in epochs since the
    last first-phase epoch.

    Returns:
        ``numpy`` float32 array with ``int(max_epoch * steps_per_epoch)`` entries.
    """
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    second_phase_t_max = int(max_epoch*1/3)
    second_phase_start = total_steps*2/3

    peak_lr = lr
    first_phase_lr = 0
    first_phase_epoch = 0

    schedule = []
    for step in range(total_steps):
        epoch = step // steps_per_epoch
        if step < warmup_steps:
            value = linear_warmup_lr(step + 1, warmup_steps, peak_lr, 0)
        elif step < second_phase_start:
            value = eta_min + (peak_lr - eta_min) * (1. + math.cos(math.pi*epoch / t_max)) / 2
            # Remember where the first phase ended; the second phase restarts from here.
            first_phase_lr = value
            first_phase_epoch = epoch
        else:
            peak_lr = first_phase_lr
            elapsed = epoch - first_phase_epoch
            value = eta_min + (peak_lr - eta_min) * (1. + math.cos(math.pi * elapsed / second_phase_t_max)) / 2
        schedule.append(value)

    return np.array(schedule).astype(np.float32)
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
    """Cosine schedule computed over ``max_epoch + 60`` epochs with 60 epochs dropped.

    The curve is evaluated over ``max_epoch + 60`` epochs, and every second
    epoch in ``[60, 180)`` is skipped so that the result has
    ``max_epoch * steps_per_epoch`` entries.  The ``t_max`` argument is
    ignored: the cosine period is always ``max_epoch + 60``.

    Returns:
        ``numpy`` float32 array with ``int(max_epoch * steps_per_epoch)`` entries.

    Raises:
        AssertionError: if the number of kept steps differs from
            ``int(max_epoch * steps_per_epoch)``.
    """
    first_skipped_epoch = 60
    skip_stride = 2
    skipped_epoch_count = 60
    skipped_epochs = range(first_skipped_epoch,
                           first_skipped_epoch + skip_stride*skipped_epoch_count,
                           skip_stride)
    extended_epochs = max_epoch+skipped_epoch_count
    t_max = extended_epochs

    total_steps = int(max_epoch * steps_per_epoch)
    extended_steps = int(extended_epochs * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    schedule = []
    for step in range(extended_steps):
        epoch = step // steps_per_epoch
        if epoch in skipped_epochs:
            continue
        if step < warmup_steps:
            value = linear_warmup_lr(step + 1, warmup_steps, lr, 0)
        else:
            value = eta_min + (lr - eta_min) * (1. + math.cos(math.pi*epoch / t_max)) / 2
        schedule.append(value)

    assert total_steps == len(schedule)
    return np.array(schedule).astype(np.float32)
|
||||
|
||||
|
||||
def get_lr(args):
    """Generate the per-step learning-rate array selected by ``args.lr_scheduler``.

    Supported values are ``'exponential'``, ``'cosine_annealing'``,
    ``'cosine_annealing_V2'`` and ``'cosine_annealing_sample'``.

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    scheduler = args.lr_scheduler

    if scheduler == 'exponential':
        return warmup_step_lr(args.lr,
                              args.lr_epochs,
                              args.steps_per_epoch,
                              args.warmup_epochs,
                              args.max_epoch,
                              gamma=args.lr_gamma)

    # The three cosine variants take the same positional arguments.
    if scheduler == 'cosine_annealing':
        cosine_fn = warmup_cosine_annealing_lr
    elif scheduler == 'cosine_annealing_V2':
        cosine_fn = warmup_cosine_annealing_lr_v2
    elif scheduler == 'cosine_annealing_sample':
        cosine_fn = warmup_cosine_annealing_lr_sample
    else:
        raise NotImplementedError(scheduler)

    return cosine_fn(args.lr,
                     args.steps_per_epoch,
                     args.warmup_epochs,
                     args.max_epoch,
                     args.t_max,
                     args.eta_min)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,188 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from .yolo import YoloLossBlock
|
||||
|
||||
|
||||
class AverageMeter:
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1
        self.val = self.avg = self.sum = self.count = 0

    def reset(self):
        """Zero the latest value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        When a ``tb_writer`` was supplied, ``val`` is also forwarded to its
        ``add_scalar`` under this meter's name and current step.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        writer = self.tb_writer
        if writer is not None:
            writer.add_scalar(self.name, self.val, self.cur_step)
        self.cur_step += 1

    def __str__(self):
        # Builds e.g. '{name}:{avg:f}' and fills it from the instance attributes.
        template = '{name}:{avg' + self.fmt + '}'
        return template.format(**self.__dict__)
|
||||
|
||||
|
||||
def load_backbone(net, ckpt_path, args):
    """Copy backbone parameters from a checkpoint into ``net``.

    Cells whose name starts with ``'feature_map.backbone'`` are matched
    against checkpoint keys after that prefix is replaced by ``'backbone'``.
    Conv2d/Dense cells receive ``weight`` and ``bias``; BatchNorm cells receive
    ``moving_mean``, ``moving_variance``, ``gamma`` and ``beta``.  Found and
    missing keys are reported through ``args.logger``.

    Args:
        net: network to update in place.
        ckpt_path: checkpoint path passed to ``load_checkpoint``.
        args: object providing a ``logger`` with an ``info`` method.

    Returns:
        The same ``net`` object.
    """
    param_dict = load_checkpoint(ckpt_path)
    yolo_backbone_prefix = 'feature_map.backbone'
    darknet_backbone_prefix = 'backbone'

    # (cell attribute, checkpoint key template) pairs for each cell family.
    dense_like_params = (('weight', '{}.weight'), ('bias', '{}.bias'))
    batchnorm_params = (('moving_mean', '{}.moving_mean'),
                        ('moving_variance', '{}.moving_variance'),
                        ('gamma', '{}.gamma'),
                        ('beta', '{}.beta'))

    find_param = []
    not_found_param = []
    net.init_parameters_data()
    for name, cell in net.cells_and_names():
        if not name.startswith(yolo_backbone_prefix):
            continue
        name = name.replace(yolo_backbone_prefix, darknet_backbone_prefix)

        if isinstance(cell, (nn.Conv2d, nn.Dense)):
            wanted = dense_like_params
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            wanted = batchnorm_params
        else:
            continue

        for attr, template in wanted:
            ckpt_key = template.format(name)
            if ckpt_key in param_dict:
                getattr(cell, attr).set_data(param_dict[ckpt_key].data)
                find_param.append(ckpt_key)
            else:
                not_found_param.append(ckpt_key)

    args.logger.info('================found_param {}========='.format(len(find_param)))
    args.logger.info(find_param)
    args.logger.info('================not_found_param {}========='.format(len(not_found_param)))
    args.logger.info(not_found_param)
    args.logger.info('=====load {} successfully ====='.format(ckpt_path))

    return net
|
||||
|
||||
|
||||
def default_wd_filter(x):
    """Return True when the parameter ``x`` should use weight decay.

    Parameters whose name ends with ``'.bias'``, ``'.gamma'`` or ``'.beta'``
    are excluded from weight decay; everything else is included.
    """
    no_decay_suffixes = ('.bias', '.gamma', '.beta')
    return not x.name.endswith(no_decay_suffixes)
|
||||
|
||||
|
||||
def get_param_groups(network):
    """Split trainable parameters into no-decay and decay groups for an optimizer.

    Parameters whose name ends with ``'.bias'``, ``'.gamma'`` or ``'.beta'``
    go into the first group, which carries ``'weight_decay': 0.0``; all other
    parameters go into the second group.

    Returns:
        ``[{'params': no_decay, 'weight_decay': 0.0}, {'params': decay}]``.
    """
    no_decay_suffixes = ('.bias', '.gamma', '.beta')
    no_decay_params, decay_params = [], []
    for param in network.trainable_params():
        bucket = no_decay_params if param.name.endswith(no_decay_suffixes) else decay_params
        bucket.append(param)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||
|
||||
|
||||
class ShapeRecord:
    """Count how often each input image size is used and report the shares."""

    def __init__(self):
        # One counter per supported size (416 to 736 in steps of 32), plus a
        # running total stored under the 'total' key.
        self.shape_record = {size: 0 for size in range(416, 737, 32)}
        self.shape_record['total'] = 0

    def set(self, shape):
        """Record one occurrence of ``shape``.

        Args:
            shape: sized object; when it has more than one element only the
                first is used.  The value is converted with ``int``.

        Raises:
            KeyError: if the size is not one of the tracked sizes.
        """
        if len(shape) > 1:
            shape = shape[0]
        shape = int(shape)
        self.shape_record[shape] += 1
        self.shape_record['total'] += 1

    def show(self, logger):
        """Log the percentage of records for every tracked size and the total.

        Args:
            logger: object with an ``info(message)`` method.

        When nothing has been recorded yet every share is reported as 0.00%
        instead of raising ``ZeroDivisionError``.
        """
        total = float(self.shape_record['total'])
        for key, count in self.shape_record.items():
            rate = count / total if total else 0.0
            logger.info('shape {}: {:.2f}%'.format(key, rate*100))
|
||||
|
||||
|
||||
def keep_loss_fp32(network):
    """Call ``to_float(mstype.float32)`` on every ``YoloLossBlock`` inside ``network``."""
    for _, sub_cell in network.cells_and_names():
        if not isinstance(sub_cell, YoloLossBlock):
            continue
        sub_cell.to_float(mstype.float32)
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue