parent 6cf308076d
commit 09c23cc82c
@@ -0,0 +1,215 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection eval."""
import os
import argparse
import matplotlib.pyplot as plt

from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import dtype as mstype
import mindspore.dataset as de

from src.data_preprocess import SingleScaleTrans
from src.config import config
from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.FaceDetection import voc_wrapper
from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_brambox, \
    parse_gt_from_anno, parse_rets, calc_recall_presicion_ap

plt.switch_backend('agg')
devid = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)


def parse_args():
    '''parse_args'''
    parser = argparse.ArgumentParser('Yolov3 Face Detection')

    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
    parser.add_argument('--world_size', type=int, default=1, help='current process number to support distributed')

    args, _ = parser.parse_known_args()

    return args


def val(args):
    '''eval'''
    print('=============yolov3 start evaluating==================')

    # pull eval settings from the shared config
    args.batch_size = config.batch_size
    args.input_shape = config.input_shape
    args.result_path = config.result_path
    args.conf_thresh = config.conf_thresh
    args.nms_thresh = config.nms_thresh

    context.set_auto_parallel_context(parallel_mode=ParallelMode.STAND_ALONE, device_num=args.world_size,
                                      gradients_mean=True)
    mindrecord_path = args.mindrecord_path
    print('Loading data from {}'.format(mindrecord_path))

    num_classes = config.num_classes
    if num_classes > 1:
        raise NotImplementedError('num_classes > 1: Yolov3 postprocess not implemented!')

    anchors = config.anchors
    anchors_mask = config.anchors_mask
    num_anchors_list = [len(x) for x in anchors_mask]

    reduction_0 = 64.0
    reduction_1 = 32.0
    reduction_2 = 16.0
    labels = ['face']
    classes = {0: 'face'}

    # dataloader
    ds = de.MindDataset(mindrecord_path + "0", columns_list=["image", "annotation", "image_name", "image_size"])

    single_scale_trans = SingleScaleTrans(resize=args.input_shape)

    ds = ds.batch(args.batch_size, per_batch_map=single_scale_trans,
                  input_columns=["image", "annotation", "image_name", "image_size"], num_parallel_workers=8)

    args.steps_per_epoch = ds.get_dataset_size()

    # backbone
    network = backbone_HwYolov3(num_classes, num_anchors_list, args)

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('load model {} success'.format(args.pretrained))
    else:
        print('load model {} failed, please check the path of model, evaluating end'.format(args.pretrained))
        exit(1)

    ds = ds.repeat(1)

    det = {}
    img_size = {}
    img_anno = {}

    model_name = args.pretrained.split('/')[-1].replace('.ckpt', '')
    result_path = os.path.join(args.result_path, model_name)
    os.makedirs(result_path, exist_ok=True)

    # result file
    ret_files_set = {
        'face': os.path.join(result_path, 'comp4_det_test_face_rm5050.txt'),
    }

    test_net = BuildTestNetwork(network, reduction_0, reduction_1, reduction_2, anchors, anchors_mask, num_classes,
                                args)

    print('conf_thresh:', args.conf_thresh)

    eval_times = 0

    for data in ds.create_tuple_iterator(output_numpy=True):
        batch_images = data[0]
        batch_labels = data[1]
        batch_image_name = data[2]
        batch_image_size = data[3]
        eval_times += 1

        img_tensor = Tensor(batch_images, mstype.float32)

        dets = []
        tdets = []

        coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2, cls_scores_2 = test_net(img_tensor)

        boxes_0, boxes_1, boxes_2 = get_bounding_boxes(coords_0, cls_scores_0, coords_1, cls_scores_1, coords_2,
                                                       cls_scores_2, args.conf_thresh, args.input_shape,
                                                       num_classes)

        converted_boxes_0, converted_boxes_1, converted_boxes_2 = tensor_to_brambox(boxes_0, boxes_1, boxes_2,
                                                                                    args.input_shape, labels)

        tdets.append(converted_boxes_0)
        tdets.append(converted_boxes_1)
        tdets.append(converted_boxes_2)

        batch = len(tdets[0])
        for b in range(batch):
            single_dets = []
            for op in range(3):
                single_dets.extend(tdets[op][b])
            dets.append(single_dets)

        det.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(dets)})
        img_size.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_image_size)})
        img_anno.update({batch_image_name[k].decode('UTF-8'): v for k, v in enumerate(batch_labels)})

    print('eval times:', eval_times)
    print('batch size: ', args.batch_size)

    netw, neth = args.input_shape
    reorg_dets = voc_wrapper.reorg_detection(det, netw, neth, img_size)
    voc_wrapper.gen_results(reorg_dets, result_path, img_size, args.nms_thresh)

    # compute mAP
    ground_truth = parse_gt_from_anno(img_anno, classes)

    ret_list = parse_rets(ret_files_set)
    iou_thr = 0.5
    evaluate = calc_recall_presicion_ap(ground_truth, ret_list, iou_thr)

    aps_str = ''
    for cls in evaluate:
        per_line, = plt.plot(evaluate[cls]['recall'], evaluate[cls]['presicion'], 'b-')
        per_line.set_label('%s:AP=%.3f' % (cls, evaluate[cls]['ap']))
        aps_str += '_%s_AP_%.3f' % (cls, evaluate[cls]['ap'])
        plt.plot([i / 1000.0 for i in range(1, 1001)], [i / 1000.0 for i in range(1, 1001)], 'y--')
        plt.axis([0, 1.2, 0, 1.2])
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.grid()

    plt.legend()
    plt.title('PR')

    # save mAP
    ap_save_path = os.path.join(result_path, result_path.replace('/', '_') + aps_str + '.png')
    print('Saving {}'.format(ap_save_path))
    plt.savefig(ap_save_path)

    print('=============yolov3 evaluating finished==================')


if __name__ == "__main__":
    arg = parse_args()
    val(arg)
@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert ckpt to air."""
import os
import argparse
import numpy as np

from mindspore import context
from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3
from src.config import config

devid = int(os.getenv('DEVICE_ID', '0'))  # default to device 0 when DEVICE_ID is unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)


def save_air(args):
    '''save air'''
    print('============= yolov3 start save air ==================')

    num_classes = config.num_classes
    anchors_mask = config.anchors_mask
    num_anchors_list = [len(x) for x in anchors_mask]

    network = backbone_HwYolov3(num_classes, num_anchors_list, args)

    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('load model {} success'.format(args.pretrained))

    input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 448, 768)).astype(np.float32)

    tensor_input_data = Tensor(input_data)
    export(network, tensor_input_data,
           file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR')

    print("export model success.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Convert ckpt to air')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')

    arg = parser.parse_args()
    save_air(arg)
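Note on the dummy input above: config.input_shape is stored as [width, height] = [768, 448] (eval.py unpacks it as netw, neth), while the export tensor is NCHW, so height 448 comes before width 768. A small illustrative cross-check, not part of the commit:

    from src.config import config

    width, height = config.input_shape   # [768, 448], stored as [width, height]
    nchw = (8, 3, height, width)         # batch of 8, RGB, NCHW layout
    assert nchw == (8, 3, 448, 768)      # matches the export dummy input above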
@@ -0,0 +1,81 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 -a $# != 3 ]
then
    echo "Usage: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE] [PRETRAINED_BACKBONE]"
    echo "   or: sh run_distribute_train.sh [MINDRECORD_FILE] [RANK_TABLE]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

current_exec_path=$(pwd)
echo ${current_exec_path}

dirname_path=$(dirname $(pwd))
echo ${dirname_path}

export PYTHONPATH=${dirname_path}:$PYTHONPATH

SCRIPT_NAME='train.py'

rm -rf ${current_exec_path}/device*

ulimit -c unlimited

MINDRECORD_FILE=$(get_real_path $1)
RANK_TABLE=$(get_real_path $2)
PRETRAINED_BACKBONE=''

if [ $# == 3 ]
then
    PRETRAINED_BACKBONE=$(get_real_path $3)
    if [ ! -f $PRETRAINED_BACKBONE ]
    then
        echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
        exit 1
    fi
fi

echo $MINDRECORD_FILE
echo $RANK_TABLE
echo $PRETRAINED_BACKBONE

export RANK_TABLE_FILE=$RANK_TABLE
export RANK_SIZE=8

echo 'start training'
for((i=0;i<=$RANK_SIZE-1;i++));
do
    echo 'start rank '$i
    mkdir ${current_exec_path}/device$i
    cd ${current_exec_path}/device$i
    export RANK_ID=$i
    dev=`expr $i + 0`
    export DEVICE_ID=$dev
    python ${dirname_path}/${SCRIPT_NAME} \
        --mindrecord_path=$MINDRECORD_FILE \
        --pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &
done

echo 'running'
@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_eval.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

current_exec_path=$(pwd)
echo ${current_exec_path}

dirname_path=$(dirname $(pwd))
echo ${dirname_path}

export PYTHONPATH=${dirname_path}:$PYTHONPATH

export RANK_SIZE=1

SCRIPT_NAME='eval.py'

ulimit -c unlimited

MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)

if [ ! -f $PRETRAINED_BACKBONE ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi

echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE

echo 'start evaluating'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
    --mindrecord_path=$MINDRECORD_FILE \
    --pretrained=$PRETRAINED_BACKBONE > eval.log 2>&1 &

echo 'running'
@@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]
then
    echo "Usage: sh run_export.sh [BATCH_SIZE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

current_exec_path=$(pwd)
echo ${current_exec_path}

dirname_path=$(dirname $(pwd))
echo ${dirname_path}

export PYTHONPATH=${dirname_path}:$PYTHONPATH

export RANK_SIZE=1

SCRIPT_NAME='export.py'

ulimit -c unlimited

BATCH_SIZE=$1
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=$(get_real_path $3)

if [ ! -f $PRETRAINED_BACKBONE ]
then
    echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
    exit 1
fi

echo $BATCH_SIZE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE

echo 'start converting'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
    --batch_size=$BATCH_SIZE \
    --pretrained=$PRETRAINED_BACKBONE > convert.log 2>&1 &

echo 'running'
@@ -0,0 +1,77 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 -a $# != 3 ]
then
    echo "Usage: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID] [PRETRAINED_BACKBONE]"
    echo "   or: sh run_standalone_train.sh [MINDRECORD_FILE] [USE_DEVICE_ID]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

current_exec_path=$(pwd)
echo ${current_exec_path}

dirname_path=$(dirname $(pwd))
echo ${dirname_path}

export PYTHONPATH=${dirname_path}:$PYTHONPATH

export RANK_SIZE=1

SCRIPT_NAME='train.py'

ulimit -c unlimited

MINDRECORD_FILE=$(get_real_path $1)
USE_DEVICE_ID=$2
PRETRAINED_BACKBONE=''

if [ $# == 3 ]
then
    PRETRAINED_BACKBONE=$(get_real_path $3)
    if [ ! -f $PRETRAINED_BACKBONE ]
    then
        echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file"
        exit 1
    fi
fi

echo $MINDRECORD_FILE
echo $USE_DEVICE_ID
echo $PRETRAINED_BACKBONE

echo 'start training'
export RANK_ID=0
rm -rf ${current_exec_path}/device$USE_DEVICE_ID
echo 'start device '$USE_DEVICE_ID
mkdir ${current_exec_path}/device$USE_DEVICE_ID
cd ${current_exec_path}/device$USE_DEVICE_ID
dev=`expr $USE_DEVICE_ID + 0`
export DEVICE_ID=$dev
python ${dirname_path}/${SCRIPT_NAME} \
    --world_size=1 \
    --mindrecord_path=$MINDRECORD_FILE \
    --pretrained=$PRETRAINED_BACKBONE > train.log 2>&1 &

echo 'running'
@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection compute final result."""
import os
import numpy as np


def remove_5050_face(dst_txt, img_size):
    '''remove_5050_face'''
    dst_txt_rm5050 = dst_txt.replace('.txt', '') + '_rm5050.txt'
    if os.path.exists(dst_txt_rm5050):
        os.remove(dst_txt_rm5050)

    write_lines = []
    with open(dst_txt, 'r') as file:
        lines = file.readlines()
        for line in lines:
            info = line.replace('\n', '').split(' ')
            img_name = info[0]
            size = img_size[img_name][0]
            w = float(info[4]) - float(info[2])
            h = float(info[5]) - float(info[3])
            ratio = max(float(size[0]) / 1920., float(size[1]) / 1080.)
            new_w = float(w) / ratio
            new_h = float(h) / ratio
            if min(new_w, new_h) >= 50.:
                write_lines.append(line)

    with open(dst_txt_rm5050, 'a') as fw:
        for line in write_lines:
            fw.write(line)


def nms(boxes, threshold=0.5):
    '''NMS.'''
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    scores = boxes[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    reserved_boxes = []
    while order.size > 0:
        i = order[0]
        reserved_boxes.append(i)
        max_x1 = np.maximum(x1[i], x1[order[1:]])
        max_y1 = np.maximum(y1[i], y1[order[1:]])
        min_x2 = np.minimum(x2[i], x2[order[1:]])
        min_y2 = np.minimum(y2[i], y2[order[1:]])

        intersect_w = np.maximum(0.0, min_x2 - max_x1 + 1)
        intersect_h = np.maximum(0.0, min_y2 - max_y1 + 1)
        intersect_area = intersect_w * intersect_h
        ovr = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)

        indexs = np.where(ovr <= threshold)[0]
        order = order[indexs + 1]

    return reserved_boxes


def gen_results(reorg_dets, results_folder, img_size, nms_thresh=0.45):
    '''gen_results'''
    for label, pieces in reorg_dets.items():
        ret = []
        dst_fp = '%s/comp4_det_test_%s.txt' % (results_folder, label)
        for name in pieces.keys():
            pred = np.array(pieces[name], dtype=np.float32)
            keep = nms(pred, nms_thresh)
            for ik in keep:
                line = '%s %f %s' % (name, pred[ik][-1], ' '.join([str(num) for num in pred[ik][:4]]))
                ret.append(line)

        with open(dst_fp, 'w') as fd:
            fd.write('\n'.join(ret))

        remove_5050_face(dst_fp, img_size)


def reorg_detection(dets, netw, neth, img_sizes):
    '''reorg_detection'''
    reorg_dets = {}
    for k, v in dets.items():
        name = k
        orig_width, orig_height = img_sizes[k][0]
        scale = min(float(netw)/orig_width, float(neth)/orig_height)
        new_width = orig_width * scale
        new_height = orig_height * scale
        pad_w = (netw - new_width) / 2.0
        pad_h = (neth - new_height) / 2.0

        for iv in v:
            xmin = iv.x_top_left
            ymin = iv.y_top_left
            xmax = xmin + iv.width
            ymax = ymin + iv.height
            conf = iv.confidence
            class_label = iv.class_label

            xmin = max(0, float(xmin - pad_w)/scale)
            xmax = min(orig_width - 1, float(xmax - pad_w)/scale)
            ymin = max(0, float(ymin - pad_h)/scale)
            ymax = min(orig_height - 1, float(ymax - pad_h)/scale)

            reorg_dets.setdefault(class_label, {})
            reorg_dets[class_label].setdefault(name, [])
            piece = (xmin, ymin, xmax, ymax, conf)
            reorg_dets[class_label][name].append(piece)

    return reorg_dets
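For reference, a quick sanity check of the nms routine above (not part of the commit); the three boxes and the threshold are made up for illustration, in the [x1, y1, x2, y2, score] format the function expects:

    import numpy as np

    boxes = np.array([
        [0., 0., 10., 10., 0.9],    # kept: highest score
        [1., 1., 10., 10., 0.8],    # suppressed: IoU with the first box is ~0.83
        [20., 20., 30., 30., 0.7],  # kept: no overlap with the first box
    ])
    print(nms(boxes, threshold=0.5))  # -> [0, 2]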
File diff suppressed because it is too large
@@ -0,0 +1,125 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection yolov3 post-process."""
import numpy as np

from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore.common import dtype as mstype


class PtLinspace(Cell):
    '''PtLinspace'''
    def __init__(self):
        super(PtLinspace, self).__init__()
        self.tuple_to_array = P.TupleToArray()

    def construct(self, start, end, steps):
        lin_x = ()
        step = (end - start + 1) / steps
        for i in range(start, end + 1, step):
            lin_x += (i,)
        lin_x = self.tuple_to_array(lin_x)
        return lin_x


class YoloPostProcess(Cell):
    """
    Yolov3 post-process of network output.
    """
    def __init__(self, num_classes, cur_anchors, conf_thresh, network_size, reduction, anchors_mask):
        super(YoloPostProcess, self).__init__()
        self.print = P.Print()
        self.num_classes = num_classes
        self.anchors = cur_anchors
        self.conf_thresh = conf_thresh
        self.network_size = network_size
        self.reduction = reduction
        self.anchors_mask = anchors_mask
        self.num_anchors = len(anchors_mask)

        anchors_w = []
        anchors_h = []
        for i in range(len(self.anchors_mask)):
            anchors_w.append(self.anchors[i][0])
            anchors_h.append(self.anchors[i][1])
        self.anchors_w = Tensor(np.array(anchors_w).reshape((1, len(self.anchors_mask), 1)))
        self.anchors_h = Tensor(np.array(anchors_h).reshape((1, len(self.anchors_mask), 1)))

        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.sigmoid = P.Sigmoid()
        self.cast = P.Cast()
        self.exp = P.Exp()
        self.concat3 = P.Concat(3)
        self.tile = P.Tile()
        self.expand_dims = P.ExpandDims()
        self.pt_linspace = PtLinspace()

    def construct(self, output):
        '''construct'''
        output_d = self.shape(output)
        num_batch = output_d[0]
        num_anchors = self.num_anchors

        num_channels = output_d[1] / num_anchors
        height = output_d[2]
        width = output_d[3]

        lin_x = self.pt_linspace(0, width - 1, width)
        lin_x = self.tile(lin_x, (height,))
        lin_x = self.cast(lin_x, mstype.float32)

        lin_y = self.pt_linspace(0, height - 1, height)
        lin_y = self.reshape(lin_y, (height, 1))
        lin_y = self.tile(lin_y, (1, width))
        lin_y = self.reshape(lin_y, (self.shape(lin_y)[0] * self.shape(lin_y)[1],))
        lin_y = self.cast(lin_y, mstype.float32)

        anchor_w = self.anchors_w
        anchor_h = self.anchors_h
        anchor_w = self.cast(anchor_w, mstype.float32)
        anchor_h = self.cast(anchor_h, mstype.float32)

        output = self.reshape(output, (num_batch, num_anchors, num_channels, height * width))

        coord_x = (self.sigmoid(output[:, :, 0, :]) + lin_x) / width
        coord_y = (self.sigmoid(output[:, :, 1, :]) + lin_y) / height
        coord_w = self.exp(output[:, :, 2, :]) * anchor_w / width
        coord_h = self.exp(output[:, :, 3, :]) * anchor_h / height
        obj_conf = self.sigmoid(output[:, :, 4, :])

        cls_conf = 0.0

        if self.num_classes > 1:
            # num_classes > 1: not implemented!
            pass
        else:
            cls_conf = self.sigmoid(output[:, :, 4, :])

        cls_scores = obj_conf * cls_conf

        coord_x_t = self.expand_dims(coord_x, 3)
        coord_y_t = self.expand_dims(coord_y, 3)
        coord_w_t = self.expand_dims(coord_w, 3)
        coord_h_t = self.expand_dims(coord_h, 3)

        coord_1 = self.concat3((coord_x_t, coord_y_t))
        coord_2 = self.concat3((coord_w_t, coord_h_t))
        coords = self.concat3((coord_1, coord_2))

        return coords, cls_scores
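For intuition, a plain-numpy sketch of the same box decode for a single anchor at one grid cell, following the standard YOLO parameterization used in construct() above; all numeric values are illustrative, not taken from the model:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    # illustrative raw outputs (tx, ty, tw, th) for one anchor at grid cell (3, 2)
    tx, ty, tw, th = 0.2, -0.1, 0.3, 0.5
    grid_w, grid_h = 12, 7           # feature-map size
    anchor_w, anchor_h = 1.5, 2.0    # anchor size in grid units

    x = (sigmoid(tx) + 3) / grid_w   # box center x, normalized to [0, 1]
    y = (sigmoid(ty) + 2) / grid_h   # box center y
    w = np.exp(tw) * anchor_w / grid_w
    h = np.exp(th) * anchor_h / grid_h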
File diff suppressed because it is too large
@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network config setting, will be used in train.py and eval.py"""
from easydict import EasyDict as ed

config = ed({
    'batch_size': 64,
    'warmup_lr': 0.0004,
    'lr_rates': [0.002, 0.004, 0.002, 0.0008, 0.0004, 0.0002, 0.00008, 0.00004, 0.000004],
    'lr_steps': [1000, 10000, 40000, 60000, 80000, 100000, 130000, 160000, 190000],
    'gamma': 0.5,
    'weight_decay': 0.0005,
    'momentum': 0.5,
    'max_epoch': 2500,

    'log_interval': 10,
    'ckpt_path': '../../output',
    'ckpt_interval': 1000,
    'result_path': '../../results',

    'input_shape': [768, 448],
    'jitter': 0.3,
    'flip': 0.5,
    'hue': 0.1,
    'sat': 1.5,
    'val': 1.5,
    'num_classes': 1,
    'anchors': [
        [3, 4],
        [5, 6],
        [7, 9],
        [10, 13],
        [15, 19],
        [21, 26],
        [28, 36],
        [38, 49],
        [54, 71],
        [77, 102],
        [122, 162],
        [207, 268],
    ],
    'anchors_mask': [(8, 9, 10, 11), (4, 5, 6, 7), (0, 1, 2, 3)],

    'conf_thresh': 0.1,
    'nms_thresh': 0.45,
})
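To make the anchor grouping above concrete: each entry of anchors_mask defines one detection head with four anchors, and eval.py derives the per-head anchor counts from it, while the target builder pairs anchors_mask[0] with the coarsest stride (reduction 64). A small sketch using only values from this config:

    anchors = [[3, 4], [5, 6], [7, 9], [10, 13], [15, 19], [21, 26],
               [28, 36], [38, 49], [54, 71], [77, 102], [122, 162], [207, 268]]
    anchors_mask = [(8, 9, 10, 11), (4, 5, 6, 7), (0, 1, 2, 3)]

    # one detection head per mask entry, four anchors each
    num_anchors_list = [len(x) for x in anchors_mask]  # -> [4, 4, 4]

    # the coarsest head (reduction 64) gets the largest anchors
    head_0_anchors = [anchors[i] for i in anchors_mask[0]]
    print(head_0_anchors)  # [[54, 71], [77, 102], [122, 162], [207, 268]]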
@@ -0,0 +1,244 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection yolov3 data pre-process."""
import numpy as np

import mindspore.dataset.vision.py_transforms as P

from src.transforms import RandomCropLetterbox, RandomFlip, HSVShift, ResizeLetterbox
from src.config import config


class SingleScaleTrans:
    '''SingleScaleTrans'''
    def __init__(self, resize, max_anno_count=200):
        self.resize = (resize[0], resize[1])
        self.max_anno_count = max_anno_count

    def __call__(self, imgs, ann, image_names, image_size, batch_info):

        size = self.resize
        decode = P.Decode()
        resize_letter_box_op = ResizeLetterbox(input_dim=size)

        to_tensor = P.ToTensor()
        ret_imgs = []
        ret_anno = []

        for i, image in enumerate(imgs):
            img_pil = decode(image)
            input_data = img_pil, ann[i]
            input_data = resize_letter_box_op(*input_data)
            image_arr = to_tensor(input_data[0])
            ret_imgs.append(image_arr)
            ret_anno.append(input_data[1])

        for i, anno in enumerate(ret_anno):
            anno_count = anno.shape[0]
            if anno_count < self.max_anno_count:
                ret_anno[i] = np.concatenate(
                    (ret_anno[i], np.zeros((self.max_anno_count - anno_count, 6), dtype=float)), axis=0)
            else:
                ret_anno[i] = ret_anno[i][:self.max_anno_count]

        return np.array(ret_imgs), np.array(ret_anno), image_names, image_size


def check_gt_negative_or_empty(gt):
    new_gt = []
    for anno in gt:
        for data in anno:
            if data not in (0, -1):
                new_gt.append(anno)
                break
    if not new_gt:
        return True, new_gt
    return False, new_gt


def bbox_ious_numpy(boxes1, boxes2):
    """ Compute IOU between all boxes from ``boxes1`` with all boxes from ``boxes2``.

    Args:
        boxes1 (np.array): List of bounding boxes
        boxes2 (np.array): List of bounding boxes

    Note:
        List format: [[xc, yc, w, h],...]
    """
    b1x1, b1y1 = np.split((boxes1[:, :2] - (boxes1[:, 2:4] / 2)), 2, axis=1)
    b1x2, b1y2 = np.split((boxes1[:, :2] + (boxes1[:, 2:4] / 2)), 2, axis=1)
    b2x1, b2y1 = np.split((boxes2[:, :2] - (boxes2[:, 2:4] / 2)), 2, axis=1)
    b2x2, b2y2 = np.split((boxes2[:, :2] + (boxes2[:, 2:4] / 2)), 2, axis=1)

    dx = np.minimum(b1x2, b2x2.transpose()) - np.maximum(b1x1, b2x1.transpose())
    dx = np.maximum(dx, 0)
    dy = np.minimum(b1y2, b2y2.transpose()) - np.maximum(b1y1, b2y1.transpose())
    dy = np.maximum(dy, 0)
    intersections = dx * dy

    areas1 = (b1x2 - b1x1) * (b1y2 - b1y1)
    areas2 = (b2x2 - b2x1) * (b2y2 - b2y1)
    unions = (areas1 + areas2.transpose()) - intersections

    return intersections / unions


def build_targets_brambox(img, anno, reduction, img_shape_para, anchors_mask, anchors):
    """
    Compare prediction boxes and ground truths, convert ground truths to network output tensors
    """
    ground_truth = anno
    img_shape = img.shape
    n_h = int(img_shape[1] / img_shape_para)  # height
    n_w = int(img_shape[2] / img_shape_para)  # width
    anchors_ori = np.array(anchors) / reduction
    num_anchor = len(anchors_mask)
    conf_pos_mask = np.zeros((num_anchor, n_h * n_w), dtype=np.float32)  # pos mask
    conf_neg_mask = np.ones((num_anchor, n_h * n_w), dtype=np.float32)  # neg mask

    # coordination mask and classification mask
    coord_mask = np.zeros((num_anchor, 1, n_h * n_w), dtype=np.float32)  # coord mask
    cls_mask = np.zeros((num_anchor, n_h * n_w), dtype=int)

    # for target coordination confidence classification
    t_coord = np.zeros((num_anchor, 4, n_h * n_w), dtype=np.float32)
    t_conf = np.zeros((num_anchor, n_h * n_w), dtype=np.float32)
    t_cls = np.zeros((num_anchor, n_h * n_w), dtype=np.float32)

    gt_list = None
    is_empty_or_negative, filtered_ground_truth = check_gt_negative_or_empty(ground_truth)
    if is_empty_or_negative:
        gt_np = np.zeros((len(ground_truth), 4), dtype=np.float32)
        gt_temp = gt_np[:]
        gt_list = gt_temp
        return coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list
    # Build up tensors
    anchors = np.concatenate([np.zeros_like(anchors_ori), anchors_ori], axis=1)
    gt = np.zeros((len(filtered_ground_truth), 4), dtype=np.float32)
    gt_np = np.zeros((len(ground_truth), 4), dtype=np.float32)
    for i, annotation in enumerate(filtered_ground_truth):
        # gt [x, y, w, h] -> [x_c, y_c, w, h]
        # division by reduction remaps the gt to the feature-map scale
        gt[i, 0] = (annotation[1] + annotation[3] / 2) / reduction
        gt[i, 1] = (annotation[2] + annotation[4] / 2) / reduction
        gt[i, 2] = annotation[3] / reduction
        gt[i, 3] = annotation[4] / reduction

        gt_np[i, 0] = annotation[1] / reduction
        gt_np[i, 1] = annotation[2] / reduction
        gt_np[i, 2] = (annotation[1] + annotation[3]) / reduction
        gt_np[i, 3] = (annotation[2] + annotation[4]) / reduction
    gt_temp = gt_np[:]
    gt_list = gt_temp

    # Find best anchor for each gt
    gt_wh = np.copy(gt)
    gt_wh[:, :2] = 0
    iou_gt_anchors = bbox_ious_numpy(gt_wh, anchors)
    best_anchors = np.argmax(iou_gt_anchors, axis=1)

    # Set masks and target values for each gt
    for i, annotation in enumerate(filtered_ground_truth):
        annotation_ignore = annotation[5]
        annotation_width = annotation[3]
        annotation_height = annotation[4]
        annotation_class_id = annotation[0]

        gi = min(n_w - 1, max(0, int(gt[i, 0])))
        gj = min(n_h - 1, max(0, int(gt[i, 1])))
        cur_n = best_anchors[i]  # best anchor for current ground truth

        if cur_n in anchors_mask:
            best_n = np.where(np.array(anchors_mask) == cur_n)[0][0]
        else:
            continue

        if annotation_ignore:
            # current annotation is ignored as difficult
            conf_pos_mask[best_n][gj * n_w + gi] = 0
            conf_neg_mask[best_n][gj * n_w + gi] = 0
        else:
            coord_mask[best_n][0][gj * n_w + gi] = 2 - annotation_width * annotation_height / \
                (n_w * n_h * reduction * reduction)
            cls_mask[best_n][gj * n_w + gi] = 1
            conf_pos_mask[best_n][gj * n_w + gi] = 1
            conf_neg_mask[best_n][gj * n_w + gi] = 0
            t_coord[best_n][0][gj * n_w + gi] = gt[i, 0] - gi
            t_coord[best_n][1][gj * n_w + gi] = gt[i, 1] - gj
            t_coord[best_n][2][gj * n_w + gi] = np.log(gt[i, 2] / anchors[cur_n, 2])
            t_coord[best_n][3][gj * n_w + gi] = np.log(gt[i, 3] / anchors[cur_n, 3])
            t_conf[best_n][gj * n_w + gi] = 1
            t_cls[best_n][gj * n_w + gi] = annotation_class_id

    return coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list


def preprocess_fn(image, annotation):
    '''preprocess_fn'''
    jitter = config.jitter
    flip = config.flip
    hue = config.hue
    sat = config.sat
    val = config.val
    size = config.input_shape
    max_anno_count = 200
    reduction_0 = 64.0
    reduction_1 = 32.0
    reduction_2 = 16.0
    anchors = config.anchors
    anchors_mask = config.anchors_mask

    decode = P.Decode()
    random_crop_letter_box_op = RandomCropLetterbox(jitter=jitter, input_dim=size)
    random_flip_op = RandomFlip(flip)
    hsv_shift_op = HSVShift(hue, sat, val)
    to_tensor = P.ToTensor()

    img_pil = decode(image)
    input_data = img_pil, annotation
    input_data = random_crop_letter_box_op(*input_data)
    input_data = random_flip_op(*input_data)
    input_data = hsv_shift_op(*input_data)
    image_arr = to_tensor(input_data[0])
    ret_img = image_arr
    ret_anno = input_data[1]

    anno_count = ret_anno.shape[0]

    if anno_count < max_anno_count:
        ret_anno = np.concatenate((ret_anno, np.zeros((max_anno_count - anno_count, 6), dtype=float)), axis=0)
    else:
        ret_anno = ret_anno[:max_anno_count]

    ret_img = np.array(ret_img)
    ret_anno = np.array(ret_anno)

    coord_mask_0, conf_pos_mask_0, conf_neg_mask_0, cls_mask_0, t_coord_0, t_conf_0, t_cls_0, gt_list_0 = \
        build_targets_brambox(ret_img, ret_anno, reduction_0, int(reduction_0), anchors_mask[0], anchors)
    coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1 = \
        build_targets_brambox(ret_img, ret_anno, reduction_1, int(reduction_1), anchors_mask[1], anchors)
    coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2 = \
        build_targets_brambox(ret_img, ret_anno, reduction_2, int(reduction_2), anchors_mask[2], anchors)

    return ret_img, ret_anno, coord_mask_0, conf_pos_mask_0, conf_neg_mask_0, cls_mask_0, t_coord_0, t_conf_0,\
        t_cls_0, gt_list_0, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1,\
        t_cls_1, gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, \
        t_cls_2, gt_list_2


compose_map_func = preprocess_fn
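A quick worked example of bbox_ious_numpy as defined above (not part of the commit); the boxes are made up, in the [xc, yc, w, h] format the docstring specifies:

    import numpy as np

    boxes1 = np.array([[5., 5., 10., 10.]])
    boxes2 = np.array([[5., 5., 10., 10.],     # identical box -> IoU 1.0
                       [20., 20., 10., 10.]])  # disjoint box  -> IoU 0.0
    print(bbox_ious_numpy(boxes1, boxes2))     # [[1. 0.]]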
@@ -0,0 +1,174 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert VOC format dataset to mindrecord for evaluating Face detection."""
import os
import xml.etree.ElementTree as ET
import numpy as np

from PIL import Image
from mindspore import log as logger
from mindspore.mindrecord import FileWriter

dataset_root_list = ["Your_VOC_dataset_path1",
                     "Your_VOC_dataset_path2",
                     "Your_VOC_dataset_pathN",
                     ]

mindrecord_file_name = "Your_output_path/data.mindrecord"

mindrecord_num = 8
is_train = False
class_indexing_1 = {'face': 0}


def prepare_file_paths():
    '''prepare_file_paths'''
    image_files = []
    anno_files = []
    image_names = []
    for dataset_root in dataset_root_list:
        if not os.path.isdir(dataset_root):
            raise ValueError("dataset root is invalid!")
        anno_dir = os.path.join(dataset_root, "Annotations")
        image_dir = os.path.join(dataset_root, "JPEGImages")
        if is_train:
            valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
        else:
            valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")

        ret_image_files, ret_anno_files, ret_image_names = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
        image_files.extend(ret_image_files)
        anno_files.extend(ret_anno_files)
        image_names.extend(ret_image_names)
    return image_files, anno_files, image_names


def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
    '''filter_valid_files_by_txt'''
    with open(valid_txt, "r") as txt:
        valid_names = txt.readlines()
    image_files = []
    anno_files = []
    image_names = []
    for name in valid_names:
        strip_name = name.strip("\n")
        anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
        if os.path.isfile(anno_joint_path):
            image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
            image_name = image_joint_path.split('/')[-1].replace('.jpg', '')
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
                image_names.append(image_name)
                continue
            image_joint_path = os.path.join(image_dir, strip_name + ".png")
            image_name = image_joint_path.split('/')[-1].replace('.png', '')
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
                image_names.append(image_name)
    return image_files, anno_files, image_names


def deserialize(member, class_indexing):
    '''deserialize'''
    class_name = member[0].text
    if class_name in class_indexing:
        class_num = class_indexing[class_name]
    else:
        return None
    bnx = member.find('bndbox')
    box_x_min = float(bnx.find('xmin').text)
    box_y_min = float(bnx.find('ymin').text)
    box_x_max = float(bnx.find('xmax').text)
    box_y_max = float(bnx.find('ymax').text)
    width = float(box_x_max - box_x_min + 1)
    height = float(box_y_max - box_y_min + 1)

    try:
        ignore = float(member.find('ignore').text)
    except (ValueError, AttributeError):  # 'ignore' tag missing or non-numeric
        ignore = 0.0
    return [class_num, box_x_min, box_y_min, width, height, ignore]


def get_data(image_file, anno_file, image_name):
    '''get_data'''
    count = 0
    annotation = []
    tree = ET.parse(anno_file)
    root = tree.getroot()

    with Image.open(image_file) as fd:
        orig_width, orig_height = fd.size

    with open(image_file, 'rb') as f:
        img = f.read()

    for member in root.findall('object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    for member in root.findall('Object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    if count == 0:
        annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
        count = 1

    data = {
        "image": img,
        "annotation": np.array(annotation, dtype='float64'),
        "image_name": image_name,
        "image_size": np.array([orig_width, orig_height], dtype='int32')
    }
    return data


def convert_yolo_data_to_mindrecord():
    '''convert_yolo_data_to_mindrecord'''

    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    yolo_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "float64", "shape": [-1, 6]},
        "image_name": {"type": "string"},
        "image_size": {"type": "int32", "shape": [-1, 2]}
    }

    print('Loading eval data...')
    image_files, anno_files, image_names = prepare_file_paths()
    dataset_size = len(anno_files)
    assert dataset_size == len(image_files)
    assert dataset_size == len(image_names)
    logger.info("#size of dataset: {}".format(dataset_size))
    data = []
    for i in range(dataset_size):
        data.append(get_data(image_files[i], anno_files[i], image_names[i]))

    print('Writing eval data to mindrecord...')
    writer.add_schema(yolo_json, "yolo_json")
    if not data:
        raise ValueError("no data to write to mindrecord.")
    writer.write_raw_data(data)
    writer.commit()


convert_yolo_data_to_mindrecord()
@@ -0,0 +1,157 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert VOC format dataset to mindrecord for training Face detection."""
import os
import xml.etree.ElementTree as ET
import numpy as np

from mindspore import log as logger
from mindspore.mindrecord import FileWriter

dataset_root_list = ["Your_VOC_dataset_path1",
                     "Your_VOC_dataset_path2",
                     "Your_VOC_dataset_pathN",
                     ]

mindrecord_file_name = "Your_output_path/data.mindrecord"

mindrecord_num = 8
is_train = True
class_indexing_1 = {'face': 0}


def prepare_file_paths():
    '''prepare_file_paths'''
    image_files = []
    anno_files = []
    for dataset_root in dataset_root_list:
        if not os.path.isdir(dataset_root):
            raise ValueError("dataset root is invalid!")
        anno_dir = os.path.join(dataset_root, "Annotations")
        image_dir = os.path.join(dataset_root, "JPEGImages")
        if is_train:
            valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
        else:
            valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")

        ret_image_files, ret_anno_files = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
        image_files.extend(ret_image_files)
        anno_files.extend(ret_anno_files)
    return image_files, anno_files


def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
    '''filter_valid_files_by_txt'''
    with open(valid_txt, "r") as txt:
        valid_names = txt.readlines()
    image_files = []
    anno_files = []
    for name in valid_names:
        strip_name = name.strip("\n")
        anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
        if os.path.isfile(anno_joint_path):
            image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
                continue
            image_joint_path = os.path.join(image_dir, strip_name + ".png")
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
    return image_files, anno_files


def deserialize(member, class_indexing):
    '''deserialize'''
    class_name = member[0].text
    if class_name in class_indexing:
        class_num = class_indexing[class_name]
    else:
        return None
    bnx = member.find('bndbox')
    box_x_min = float(bnx.find('xmin').text)
    box_y_min = float(bnx.find('ymin').text)
    box_x_max = float(bnx.find('xmax').text)
    box_y_max = float(bnx.find('ymax').text)
    width = float(box_x_max - box_x_min + 1)
    height = float(box_y_max - box_y_min + 1)

    try:
        ignore = float(member.find('ignore').text)
    except (ValueError, AttributeError):  # 'ignore' tag missing or non-numeric
        ignore = 0.0
    return [class_num, box_x_min, box_y_min, width, height, ignore]


def get_data(image_file, anno_file):
    '''get_data'''
    count = 0
    annotation = []
    tree = ET.parse(anno_file)
    root = tree.getroot()

    with open(image_file, 'rb') as f:
        img = f.read()

    for member in root.findall('object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    for member in root.findall('Object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    if count == 0:
        annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
        count = 1
    data = {
        "image": img,
        "annotation": np.array(annotation, dtype='float64')
    }
    return data


def convert_yolo_data_to_mindrecord():
    '''convert_yolo_data_to_mindrecord'''

    writer = FileWriter(mindrecord_file_name, mindrecord_num)
    yolo_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "float64", "shape": [-1, 6]}
    }

    print('Loading train data...')
    image_files, anno_files = prepare_file_paths()
    dataset_size = len(anno_files)
    assert dataset_size == len(image_files)
    logger.info("#size of dataset: {}".format(dataset_size))
    data = []
    for i in range(dataset_size):
        data.append(get_data(image_files[i], anno_files[i]))

    print('Writing train data to mindrecord...')
    writer.add_schema(yolo_json, "yolo_json")
    if not data:
        raise ValueError("no data to write to mindrecord.")
    writer.write_raw_data(data)
    writer.commit()


convert_yolo_data_to_mindrecord()
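As a hedged sanity check (not part of the commit): FileWriter with shard number 8 produces files data.mindrecord0 through data.mindrecord7, and eval.py reads them via MindDataset by appending "0" to the base path. The records written above can be read back the same way; the path is the placeholder from the script:

    import mindspore.dataset as de

    # read the first shard back; columns match the yolo_json schema above
    ds = de.MindDataset("Your_output_path/data.mindrecord0",
                        columns_list=["image", "annotation"])
    for item in ds.create_dict_iterator(output_numpy=True):
        print(item["annotation"].shape)  # (N, 6): [class, x, y, w, h, ignore]
        break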
@ -0,0 +1,154 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Add VOC format dataset to an existed mindrecord for training Face detection."""
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
import numpy as np
|
||||
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
dataset_root_list = ["Your_VOC_dataset_path1",
|
||||
"Your_VOC_dataset_path2",
|
||||
"Your_VOC_dataset_pathN",
|
||||
]
|
||||
|
||||
mindrecord_file_name = "Your_previous_output_path/data.mindrecord0"
|
||||
|
||||
mindrecord_num = 8
|
||||
is_train = True
|
||||
class_indexing_1 = {'face': 0}
|
||||
|
||||
|
||||
def prepare_file_paths():
|
||||
'''prepare file paths'''
|
||||
image_files = []
|
||||
anno_files = []
|
||||
for dataset_root in dataset_root_list:
|
||||
if not os.path.isdir(dataset_root):
|
||||
raise ValueError("dataset root is unvalid!")
|
||||
anno_dir = os.path.join(dataset_root, "Annotations")
|
||||
image_dir = os.path.join(dataset_root, "JPEGImages")
|
||||
if is_train:
|
||||
valid_txt = os.path.join(dataset_root, "ImageSets/Main/train.txt")
|
||||
else:
|
||||
valid_txt = os.path.join(dataset_root, "ImageSets/Main/test.txt")
|
||||
|
||||
ret_image_files, ret_anno_files = filter_valid_files_by_txt(image_dir, anno_dir, valid_txt)
|
||||
image_files.extend(ret_image_files)
|
||||
anno_files.extend(ret_anno_files)
|
||||
return image_files, anno_files
|
||||
|
||||
|
||||
def filter_valid_files_by_txt(image_dir, anno_dir, valid_txt):
    '''filter valid files by txt'''
    with open(valid_txt, "r") as txt:
        valid_names = txt.readlines()
    image_files = []
    anno_files = []
    for name in valid_names:
        strip_name = name.strip("\n")
        anno_joint_path = os.path.join(anno_dir, strip_name + ".xml")
        if os.path.isfile(anno_joint_path):
            image_joint_path = os.path.join(image_dir, strip_name + ".jpg")
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
                continue
            image_joint_path = os.path.join(image_dir, strip_name + ".png")
            if os.path.isfile(image_joint_path):
                image_files.append(image_joint_path)
                anno_files.append(anno_joint_path)
    return image_files, anno_files
def deserialize(member, class_indexing):
    '''deserialize'''
    class_name = member[0].text
    if class_name in class_indexing:
        class_num = class_indexing[class_name]
    else:
        return None
    bnx = member.find('bndbox')
    box_x_min = float(bnx.find('xmin').text)
    box_y_min = float(bnx.find('ymin').text)
    box_x_max = float(bnx.find('xmax').text)
    box_y_max = float(bnx.find('ymax').text)
    width = float(box_x_max - box_x_min + 1)
    height = float(box_y_max - box_y_min + 1)

    try:
        ignore = float(member.find('ignore').text)
    except (ValueError, AttributeError):
        # <ignore> is optional: treat a missing or malformed tag as "do not ignore".
        ignore = 0.0
    return [class_num, box_x_min, box_y_min, width, height, ignore]
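# Illustrative sketch (an assumption, not a file from the dataset) of the VOC
# object entry deserialize() expects; the class name must be the first child
# element, because the code reads member[0].text:
#
#   <object>
#       <name>face</name>
#       <bndbox>
#           <xmin>10</xmin>
#           <ymin>20</ymin>
#           <xmax>110</xmax>
#           <ymax>140</ymax>
#       </bndbox>
#       <ignore>0</ignore>   <!-- optional; treated as 0.0 when absent -->
#   </object>
#
# For this entry, deserialize(member, {'face': 0}) returns
# [0, 10.0, 20.0, 101.0, 121.0, 0.0], i.e. width and height are computed
# as max - min + 1.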
def get_data(image_file, anno_file):
    '''get_data'''
    count = 0
    annotation = []
    tree = ET.parse(anno_file)
    root = tree.getroot()

    with open(image_file, 'rb') as f:
        img = f.read()

    # Some annotation files capitalize the tag, so scan both 'object' and 'Object'.
    for member in root.findall('object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    for member in root.findall('Object'):
        anno = deserialize(member, class_indexing_1)
        if anno is not None:
            annotation.extend(anno)
            count += 1

    if count == 0:
        # No valid object in this image: write a single all -1 placeholder row.
        annotation = np.array([[-1, -1, -1, -1, -1, -1]], dtype='float64')
        count = 1
    data = {
        "image": img,
        "annotation": np.array(annotation, dtype='float64')
    }
    return data
def convert_yolo_data_to_mindrecord():
    '''convert_yolo_data_to_mindrecord'''

    print('Loading mindrecord...')
    # The schema was defined when the mindrecord was first created, so this
    # script only appends to it; no add_schema call is needed here.
    writer = FileWriter.open_for_append(mindrecord_file_name)

    print('Loading train data...')
    image_files, anno_files = prepare_file_paths()
    dataset_size = len(anno_files)
    assert dataset_size == len(image_files)
    logger.info("#size of dataset: {}".format(dataset_size))
    data = []
    for i in range(dataset_size):
        data.append(get_data(image_files[i], anno_files[i]))

    print('Writing train data to mindrecord...')
    if not data:
        raise ValueError("No data to write to mindrecord.")
    writer.write_raw_data(data)
    writer.commit()


convert_yolo_data_to_mindrecord()
@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom logger."""
import logging
import os
import sys
from datetime import datetime

logger_name_1 = 'yolov3_face_detection'
class LOGGER(logging.Logger):
    '''LOGGER'''
    def __init__(self, logger_name):
        super(LOGGER, self).__init__(logger_name)
        console = logging.StreamHandler(sys.stdout)
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        console.setFormatter(formatter)
        self.addHandler(console)
        self.local_rank = 0

    def setup_logging_file(self, log_dir, local_rank=0):
        '''setup_logging_file'''
        self.local_rank = local_rank
        if self.local_rank == 0:
            if not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '.log'
            self.log_fn = os.path.join(log_dir, log_name)
            fh = logging.FileHandler(self.log_fn)
            fh.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            fh.setFormatter(formatter)
            self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.local_rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*'*70 + '\n')*line_width
            important_msg += ('*'*line_width + '\n')*2
            important_msg += '*'*line_width + ' '*8 + msg + '\n'
            important_msg += ('*'*line_width + '\n')*2
            important_msg += ('*'*70 + '\n')*line_width
            self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
    logger = LOGGER(logger_name_1)
    logger.setup_logging_file(path, rank)
    return logger
class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f', tb_writer=None):
        self.name = name
        self.fmt = fmt
        self.reset()
        self.tb_writer = tb_writer
        self.cur_step = 1

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if self.tb_writer is not None:
            self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
            self.cur_step += 1

    def __str__(self):
        fmtstr = '{name}:{avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
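# Illustrative usage sketch (assumed, not part of the diff): wiring get_logger
# and AverageMeter together in a toy loop; the log directory and loss values
# are made up for illustration.
if __name__ == '__main__':
    logger_demo = get_logger('./output/log', 0)  # rank 0 also writes a log file
    loss_meter = AverageMeter('loss', ':.4f')
    for step, loss_value in enumerate([0.9, 0.7, 0.5]):
        loss_meter.update(loss_value)
        logger_demo.info('step %d, %s', step, loss_meter)
    logger_demo.important_info('training done')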
@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face detection learning rate scheduler."""
import math
from collections import Counter

import numpy as np
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    learning_rate = float(init_lr) + lr_inc * current_step
    return learning_rate
def warmup_step(args, gamma=0.1, lr_scale=1.0):
    '''warmup_step'''
    base_lr = args.lr
    warmup_init_lr = 0
    total_steps = int(args.max_epoch * args.steps_per_epoch)
    warmup_steps = int(args.warmup_epochs * args.steps_per_epoch)
    milestones = args.lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone*args.steps_per_epoch
        milestones_steps.append(milestones_step)
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_learning_rate(
                i, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr_scale * lr * gamma**milestones_steps_counter[i]
        print('i:{} lr:{}'.format(i, lr))
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)
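# Worked example (toy values, assumed for illustration): with args.lr = 0.1,
# args.steps_per_epoch = 2, args.max_epoch = 4, args.warmup_epochs = 1 and
# args.lr_epochs = [2, 3], warmup_step(args) produces 8 per-step rates:
#
#   step: 0    1     2     3     4      5      6       7
#   lr:   0.0  0.05  0.05  0.05  0.005  0.005  0.0005  0.0005
#
# Note that the decay multiplies the *running* lr, so after warmup the
# schedule continues from the last warmup value (0.05) rather than jumping
# back to base_lr.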
def warmup_step_new(args, lr_scale=1.0):
    '''warmup_step_new'''
    warmup_lr = args.warmup_lr / args.batch_size
    lr_rates = [lr_rate / args.batch_size * lr_scale for lr_rate in args.lr_rates]
    total_steps = int(args.max_epoch * args.steps_per_epoch)
    lr_steps = args.lr_steps
    warmup_steps = lr_steps[0]
    lr_left = 0
    print('real warmup_lr', warmup_lr)
    print('real lr_rates', lr_rates)
    if args.max_epoch * args.steps_per_epoch > lr_steps[-1]:
        lr_steps.append(args.max_epoch * args.steps_per_epoch)
        lr_rates.append(lr_rates[-1])
    lr_each_step = []
    for i in range(total_steps):
        if i < warmup_steps:
            lr = warmup_lr
        elif i < lr_steps[lr_left]:
            lr = lr_rates[lr_left]
        else:
            lr_left = lr_left + 1
        lr_each_step.append(lr)
    return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0):
    '''warmup_cosine_annealing_lr'''
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_learning_rate(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)
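# Illustrative sketch (an assumption, not part of the diff): a quick sanity
# check of the cosine schedule with toy numbers, plus how such a per-step
# array is typically consumed.
if __name__ == '__main__':
    demo_lr = warmup_cosine_annealing_lr(0.1, steps_per_epoch=4, warmup_epochs=1,
                                         max_epoch=3, t_max=3)
    print(demo_lr.shape)  # (12,): one learning rate per training step
    # An array like this is usually wrapped in a Tensor and handed to a
    # MindSpore optimizer, e.g. mindspore.nn.Momentum(params, Tensor(demo_lr), 0.9).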