parent 67c3fded73
commit 6afa96de4d
@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet evaluation script.
"""

import os
import time
import copy
import json
import argparse
import cv2
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.log as logger
from src import COCOHP, CenterNetMultiPoseEval
from src import convert_eval_format, post_process, merge_outputs
from src import visual_image
from src.config import dataset_config, net_config, eval_config

_current_dir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser(description='CenterNet evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
                    "the absolute image path is joined by the data_dir "
                    "and the relative path in anno_path")
parser.add_argument("--run_mode", type=str, default="test", help="test or validation, default is test.")
parser.add_argument("--visual_image", type=str, default="false", help="Visualize the ground truth and predicted image")
parser.add_argument("--enable_eval", type=str, default="true", help="Whether to evaluate accuracy after prediction")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")

args_opt = parser.parse_args()


def predict():
    '''
    Predict function
    '''
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)

    logger.info("Begin creating {} dataset".format(args_opt.run_mode))
    coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode=args_opt.run_mode)
    coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,
              keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
    dataset = coco.create_eval_dataset()

    net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
    net_for_eval.set_train(False)

    param_dict = load_checkpoint(args_opt.load_checkpoint_path)
    load_param_into_net(net_for_eval, param_dict)

    # save results
    save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if args_opt.visual_image == "true":
        save_pred_image_path = os.path.join(save_path, "pred_image")
        if not os.path.exists(save_pred_image_path):
            os.makedirs(save_pred_image_path)
        save_gt_image_path = os.path.join(save_path, "gt_image")
        if not os.path.exists(save_gt_image_path):
            os.makedirs(save_gt_image_path)

    total_nums = dataset.get_dataset_size()
    print("\n========================================\n")
    print("Total images num: ", total_nums)
    print("Processing, please wait a moment.")

    pred_annos = {"images": [], "annotations": []}

    index = 0
    for data in dataset.create_dict_iterator(num_epochs=1):
        index += 1
        image = data['image']
        image_id = data['image_id'].asnumpy().reshape((-1))[0]

        # run prediction
        start = time.time()
        detections = []
        for scale in eval_config.multi_scales:
            images, meta = coco.pre_process_for_test(image.asnumpy(), image_id, scale)
            detection = net_for_eval(Tensor(images))
            dets = post_process(detection.asnumpy(), meta, scale)
            detections.append(dets)
        end = time.time()
        print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.))

        # post-process
        soft_nms = eval_config.soft_nms or len(eval_config.multi_scales) > 0
        detections = merge_outputs(detections, soft_nms)

        # get prediction result
        pred_json = convert_eval_format(detections, image_id)
        gt_image_info = coco.coco.loadImgs([image_id])

        for image_info in pred_json["images"]:
            pred_annos["images"].append(image_info)
        for image_anno in pred_json["annotations"]:
            pred_annos["annotations"].append(image_anno)
        if args_opt.visual_image == "true":
            img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name'])
            gt_image = cv2.imread(img_file)
            if args_opt.run_mode != "test":
                annos = coco.coco.loadAnns(coco.anns[image_id])
                visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path)
            anno = copy.deepcopy(pred_json["annotations"])
            visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh)

    # save results
    save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    pred_anno_file = os.path.join(save_path, '{}_pred_result.json').format(args_opt.run_mode)
    json.dump(pred_annos, open(pred_anno_file, 'w'))
    pred_res_file = os.path.join(save_path, '{}_pred_eval.json').format(args_opt.run_mode)
    json.dump(pred_annos["annotations"], open(pred_res_file, 'w'))

    if args_opt.run_mode != "test" and args_opt.enable_eval == "true":
        run_eval(coco.annot_path, pred_res_file)


def run_eval(gt_anno, pred_anno):
    """evaluation by coco api"""
    coco = COCO(gt_anno)
    coco_dets = coco.loadRes(pred_anno)
    coco_eval = COCOeval(coco, coco_dets, "keypoints")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    coco_eval = COCOeval(coco, coco_dets, "bbox")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


if __name__ == "__main__":
    predict()
@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export CenterNet mindir model.
"""

import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export

from src import CenterNetMultiPoseEval
from src.config import net_config, eval_config, export_config

parser = argparse.ArgumentParser(description='centernet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)

if __name__ == '__main__':
    net = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
    net.set_train(False)

    param_dict = load_checkpoint(export_config.ckpt_file)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    input_shape = [1, 3, export_config.input_res[0], export_config.input_res[1]]
    input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32))

    export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format)
@@ -0,0 +1,51 @@
# Run distribute train

## description

The number of Ascend accelerators is allocated automatically based on the device_num set in the hccl config file; you do not need to specify it.
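As a concrete illustration of the paragraph above (not part of the submitted code; the helper name `count_ranks` is made up for this sketch), the launcher derives the total device count by summing the `device` entries of every server in the hccl rank table:

```python
# Illustrative sketch: count the devices declared in an hccl rank table
# (e.g. hccl_2p_56_x.x.x.x.json), mirroring how the launcher computes rank_size.
import json

def count_ranks(rank_table_path):
    """Return the total number of devices listed under server_list."""
    with open(rank_table_path, "r", encoding="utf-8") as fin:
        rank_table = json.load(fin)
    return sum(len(server["device"]) for server in rank_table["server_list"])
```

For the 2-device rank table used in the example below, this yields 2, which the launcher exports as `RANK_SIZE`.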
## how to use

For example, to generate the launch command for distributed training of the CenterNet model on Ascend accelerators, run the following command in the `/centernet/` dir:

```bash
python ./scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py --run_script_dir ./train.py --hyper_parameter_config_dir ./scripts/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --mindrecord_dir /path/mindrecord_dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
```

output:

```text
hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
the number of logical core: 192
avg_core_per_rank: 96
rank_size: 2

start training for rank 0, device 5:
rank_id: 0
device_id: 5
core nums: 0-95
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG5/training_log.txt

start training for rank 1, device 6:
rank_id: 1
device_id: 6
core nums: 96-191
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG6/training_log.txt
```

## Note

1. Note that `hccl_2p_56_x.x.x.x.json` can be generated with [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).

2. For hyper parameters, customize the script `hyper_parameter_config.ini`. Note that the following three hyper parameters are not allowed to be configured there, because the launcher sets them itself (see the sketch after this list):
    - device_id
    - device_num
    - data_dir

3. For other models, customize the option `run_script` and the corresponding `hyper_parameter_config.ini`.
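To make notes 2 and 3 concrete, the sketch below (illustrative only; `build_options` is not a function in the submitted code) shows how the launcher expands the `[config]` section of `hyper_parameter_config.ini` into `--key=value` options and rejects the keys it reserves for itself:

```python
# Illustrative sketch: turn hyper_parameter_config.ini into command-line options
# the way the launcher does, rejecting the keys the launcher sets itself.
import configparser

RESERVED_KEYS = ("device_id", "device_num", "data_dir")

def build_options(config_path):
    cf = configparser.ConfigParser()
    cf.read(config_path)
    cfg = dict(cf.items("config"))
    for key in RESERVED_KEYS:
        if key in cfg:
            raise ValueError("hyper_parameter_config.ini can not set {!r}".format(key))
    return " ".join("--{}={}".format(key, value) for key, value in cfg.items())
```

The launcher appends `--data_dir`, `--mindrecord_dir`, `--device_id` and `--device_num` itself when it assembles each per-device command.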
@@ -0,0 +1,170 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""distribute pretrain script"""
import os
import json
import configparser
import multiprocessing
from argparse import ArgumentParser


def parse_args():
    """
    parse args.

    Args:

    Returns:
        args.

    Examples:
        >>> parse_args()
    """
    parser = ArgumentParser(description="mindspore distributed training")

    parser.add_argument("--run_script_dir", type=str, default="",
                        help="Run script path, it is better to use absolute path")
    parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
                        help="Hyper Parameter config path, it is better to use absolute path")
    parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
                        help="Mindrecord directory. If the mindrecord_dir is empty, it will generate mindrecord files by "
                             "data_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
                             "rather than data_dir and anno_path. Default is ./Mindrecord_train")
    parser.add_argument("--data_dir", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--hccl_config_dir", type=str, default="",
                        help="Hccl config path, it is better to use absolute path")
    parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
                        help="Path of the generated cmd file.")
    parser.add_argument("--hccl_time_out", type=int, default=120,
                        help="Seconds to determine the hccl time out, "
                             "default: 120, which is the same as hccl default config")

    args = parser.parse_args()
    return args


def append_cmd(cmd, s):
    cmd += s
    cmd += "\n"
    return cmd


def append_cmd_env(cmd, key, value):
    return append_cmd(cmd, "export " + str(key) + "=" + str(value))


def distribute_train():
    """
    distribute pretrain scripts. The number of Ascend accelerators can be automatically allocated
    based on the device_num set in the hccl config file; you do not need to specify it.
    """
    cmd = ""
    print("start", __file__)
    args = parse_args()

    run_script = args.run_script_dir
    data_dir = args.data_dir
    mindrecord_dir = args.mindrecord_dir
    cf = configparser.ConfigParser()
    cf.read(args.hyper_parameter_config_dir)
    cfg = dict(cf.items("config"))

    print("hccl_config_dir:", args.hccl_config_dir)
    print("hccl_time_out:", args.hccl_time_out)
    cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
    cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)

    cores = multiprocessing.cpu_count()
    print("the number of logical core:", cores)

    # get device_ips
    device_ips = {}
    with open('/etc/hccn.conf', 'r') as fin:
        for hccn_item in fin.readlines():
            if hccn_item.strip().startswith('address_'):
                device_id, device_ip = hccn_item.split('=')
                device_id = device_id.split('_')[1]
                device_ips[device_id] = device_ip.strip()

    with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
        hccl_config = json.loads(fin.read())
        rank_size = 0
        for server in hccl_config["server_list"]:
            rank_size += len(server["device"])
            if server["device"][0]["device_ip"] in device_ips.values():
                this_server = server

    cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
    print("total rank size:", rank_size)
    print("this server rank size:", len(this_server["device"]))
    avg_core_per_rank = int(int(cores) / len(this_server["device"]))
    core_gap = avg_core_per_rank - 1
    print("avg_core_per_rank:", avg_core_per_rank)

    count = 0
    for instance in this_server["device"]:
        device_id = instance["device_id"]
        rank_id = instance["rank_id"]
        print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
        print("rank_id:", rank_id)
        print("device_id:", device_id)

        start = count * int(avg_core_per_rank)
        count += 1
        end = start + core_gap
        cmdopt = str(start) + "-" + str(end)

        cmd = append_cmd_env(cmd, "DEVICE_ID", str(device_id))
        cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
        cmd = append_cmd_env(cmd, "DEPLOY_MODE", '0')
        cmd = append_cmd_env(cmd, "GE_USE_STATIC_MEMORY", '1')

        cmd = append_cmd(cmd, "rm -rf LOG" + str(device_id))
        cmd = append_cmd(cmd, "mkdir ./LOG" + str(device_id))
        cmd = append_cmd(cmd, "cp *.py ./LOG" + str(device_id))
        cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(device_id) + "/ms_log")
        cmd = append_cmd(cmd, "env > ./LOG" + str(device_id) + "/env.log")

        cur_dir = os.getcwd()
        cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(device_id) + "/ms_log")
        cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")

        print("core_nums:", cmdopt)
        print("epoch_size:", str(cfg['epoch_size']))
        print("data_dir:", data_dir)
        print("mindrecord_dir:", mindrecord_dir)
        print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/training_log.txt")

        cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(device_id))

        run_cmd = 'taskset -c ' + cmdopt + ' nohup python ' + run_script + " "
        opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()])
        if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt):
            raise ValueError("hyper_parameter_config.ini can not set 'device_id',"
                             " 'device_num' or 'data_dir'! ")
        run_cmd += opt
        run_cmd += " --data_dir=" + data_dir
        run_cmd += " --mindrecord_dir=" + mindrecord_dir
        run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
                   + str(rank_size) + ' >./training_log.txt 2>&1 &'

        cmd = append_cmd(cmd, run_cmd)
        cmd = append_cmd(cmd, "cd -")
        cmd += "\n"

    with open(args.cmd_file, "w") as f:
        f.write(cmd)


if __name__ == "__main__":
    distribute_train()
@@ -0,0 +1,14 @@
[config]
distribute=true
epoch_size=350
enable_save_ckpt=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=50
load_checkpoint_path=""
save_checkpoint_path=./
save_checkpoint_steps=3000
save_checkpoint_num=1
need_profiler=false
profiler_path=./profiler
visual_image=false
@@ -0,0 +1,36 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_train_ascend.sh DATA_DIR MINDRECORD_DIR RANK_TABLE_FILE"
echo "for example: bash run_distributed_train_ascend.sh /path/dataset /path/mindrecord /path/hccl.json"
echo "It is better to use absolute path."
echo "For hyper parameters, please note that you should customize the script:
          '{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "=============================================================================================================="
CUR_DIR=`pwd`

python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
    --run_script_dir=${CUR_DIR}/train.py \
    --hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
    --data_dir=$1 \
    --mindrecord_dir=$2 \
    --hccl_config_dir=$3 \
    --hccl_time_out=1200 \
    --cmd_file=distributed_cmd.sh

bash distributed_cmd.sh
@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_eval_ascend.sh DEVICE_ID"
echo "for example: bash run_standalone_eval_ascend.sh 0"
echo "=============================================================================================================="

DEVICE_ID=$1
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0

# install nms module from third party
if python -c "import nms" > /dev/null 2>&1
then
    echo "NMS module already exists, no need to reinstall."
else
    echo "NMS module was not found, installing it now..."
    git clone https://github.com/xingyizhou/CenterNet.git
    cd CenterNet/src/lib/external/
    make
    python setup.py install
    cd -
    rm -rf CenterNet
fi

python ${PROJECT_DIR}/../eval.py \
    --device_id=$DEVICE_ID \
    --load_checkpoint_path="" \
    --data_dir="" \
    --visual_image=true \
    --enable_eval=true \
    --save_result_dir="" \
    --run_mode=val > log.txt 2>&1 &
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE"
echo "for example: bash run_standalone_pretrain_ascend.sh 0 350"
echo "=============================================================================================================="

DEVICE_ID=$1
EPOCH_SIZE=$2

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0

python ${PROJECT_DIR}/../train.py \
    --distribute=false \
    --need_profiler=false \
    --profiler_path=./profiler \
    --epoch_size=$EPOCH_SIZE \
    --device_id=$DEVICE_ID \
    --enable_save_ckpt=true \
    --do_shuffle=true \
    --enable_data_sink=true \
    --data_sink_steps=50 \
    --load_checkpoint_path="" \
    --save_checkpoint_steps=10000 \
    --save_checkpoint_num=1 \
    --data_dir="" \
    --mindrecord_dir="" \
    --visual_image=false \
    --save_result_dir="" > log.txt 2>&1 &
@@ -0,0 +1,28 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CenterNet Init."""

from .centernet_pose import GatherMultiPoseFeatureCell, CenterNetMultiPoseLossCell, \
    CenterNetWithLossScaleCell, CenterNetMultiPoseEval
from .dataset import COCOHP
from .visual import visual_allimages, visual_image
from .decode import MultiPoseDecode
from .post_process import convert_eval_format, to_float, resize_detection, post_process, merge_outputs

__all__ = [
    "GatherMultiPoseFeatureCell", "CenterNetMultiPoseLossCell", "CenterNetWithLossScaleCell",
    "CenterNetMultiPoseEval", "COCOHP", "visual_allimages", "visual_image", "MultiPoseDecode",
    "convert_eval_format", "to_float", "resize_detection", "post_process", "merge_outputs"
]
@@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Post-process functions after decoding
"""

import numpy as np
from src.config import dataset_config as config
from .image import get_affine_transform, affine_transform, transform_preds
from .visual import coco_box_to_bbox

try:
    from nms import soft_nms_39
except ImportError:
    print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
          'and see run_standalone_eval.sh for more details to install it\n')

_NUM_JOINTS = config.num_joints


def post_process(dets, meta, scale=1):
    """rescale detection to original scale"""
    c, s, h, w = meta['c'], meta['s'], meta['out_height'], meta['out_width']
    b, K, N = dets.shape
    assert b == 1, "only single image was post-processed"
    dets = dets.reshape((K, N))
    bbox = transform_preds(dets[:, :4].reshape(-1, 2), c, s, (w, h)) / scale
    pts = transform_preds(dets[:, 5:39].reshape(-1, 2), c, s, (w, h)) / scale
    top_preds = np.concatenate(
        [bbox.reshape(-1, 4), dets[:, 4:5],
         pts.reshape(-1, 34)], axis=1).astype(np.float32).reshape(-1, 39)
    return top_preds


def merge_outputs(detections, soft_nms=True):
    """merge detections together by nms"""
    results = np.concatenate([detection for detection in detections], axis=0).astype(np.float32)
    if soft_nms:
        soft_nms_39(results, Nt=0.5, threshold=0.01, method=2)
    results = results.tolist()
    return results


def convert_eval_format(detections, img_id):
    """convert detection to annotation json format"""
    # detections. scores: (b, K); bboxes: (b, K, 4); kps: (b, K, J * 2); clses: (b, K)
    # only batch_size = 1 is supported
    detections = np.array(detections).reshape((-1, 39))
    pred_anno = {"images": [], "annotations": []}
    num_objs, _ = detections.shape
    for i in range(num_objs):
        score = detections[i][4]
        bbox = detections[i][0:4]
        bbox[2:4] = bbox[2:4] - bbox[0:2]
        bbox = list(map(to_float, bbox))
        keypoints = np.concatenate([
            np.array(detections[i][5:39], dtype=np.float32).reshape(-1, 2),
            np.ones((17, 1), dtype=np.float32)], axis=1).reshape(_NUM_JOINTS * 3).tolist()
        keypoints = list(map(to_float, keypoints))
        class_id = 1
        pred = {
            "image_id": int(img_id),
            "category_id": int(class_id),
            "bbox": bbox,
            "score": to_float(score),
            "keypoints": keypoints
        }
        pred_anno["annotations"].append(pred)
    if pred_anno["annotations"]:
        pred_anno["images"].append({"id": int(img_id)})
    return pred_anno


def to_float(x):
    """format float data"""
    return float("{:.2f}".format(x))


def resize_detection(detection, pred, gt):
    """resize object annotation info"""
    height, width = gt[0], gt[1]
    c = np.array([pred[1] / 2., pred[0] / 2.], dtype=np.float32)
    s = max(pred[0], pred[1]) * 1.0
    trans_output = get_affine_transform(c, s, 0, [width, height])

    anns = detection["annotations"]
    num_objects = len(anns)
    resized_detection = {"images": detection["images"], "annotations": []}
    for i in range(num_objects):
        ann = anns[i]
        bbox = coco_box_to_bbox(ann['bbox'])
        pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)

        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
        bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        for j in range(_NUM_JOINTS):
            pts[j, :2] = affine_transform(pts[j, :2], trans_output)

        bbox = [bbox[0], bbox[1], w, h]
        keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
        ann["bbox"] = list(map(to_float, bbox))
        ann["keypoints"] = list(map(to_float, keypoints))
        resized_detection["annotations"].append(ann)
    return resized_detection
@@ -0,0 +1,175 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py
"""

import os
import json
import random
import cv2
import numpy as np
import pycocotools.coco as COCO
from .config import dataset_config as data_cfg
from .image import get_affine_transform, affine_transform

_NUM_JOINTS = data_cfg.num_joints


def coco_box_to_bbox(box):
    """convert height/width to position coordinates"""
    bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
    return bbox


def resize_image(image, anns, width, height):
    """resize image to specified scale"""
    h, w = image.shape[0], image.shape[1]
    c = np.array([image.shape[1] / 2., image.shape[0] / 2.], dtype=np.float32)
    s = max(image.shape[0], image.shape[1]) * 1.0
    trans_output = get_affine_transform(c, s, 0, [width, height])
    out_img = cv2.warpAffine(image, trans_output, (width, height), flags=cv2.INTER_LINEAR)

    num_objects = len(anns)
    resize_anno = []
    for i in range(num_objects):
        ann = anns[i]
        bbox = coco_box_to_bbox(ann['bbox'])
        pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)

        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
        bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        if (h > 0 and w > 0):
            ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
            for j in range(_NUM_JOINTS):
                pts[j, :2] = affine_transform(pts[j, :2], trans_output)

            bbox = [ct[0] - w / 2, ct[1] - h / 2, w, h, 1]
            keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
            ann["bbox"] = bbox
            ann["keypoints"] = keypoints
            gt = ann
            resize_anno.append(gt)
    return out_img, resize_anno


def merge_pred(ann_path, mode="val", name="merged_annotations"):
    """merge annotation info of each image together"""
    files = os.listdir(ann_path)
    data_files = []
    for file_name in files:
        if "json" in file_name:
            data_files.append(os.path.join(ann_path, file_name))
    pred = {"images": [], "annotations": []}
    for file in data_files:
        anno = json.load(open(file, 'r'))
        if "images" in anno:
            for img in anno["images"]:
                pred["images"].append(img)
        if "annotations" in anno:
            for ann in anno["annotations"]:
                pred["annotations"].append(ann)
    json.dump(pred, open('{}/{}_{}.json'.format(ann_path, name, mode), 'w'))


def visual(ann_path, image_path, save_path, ratio=1, mode="val", name="merged_annotations"):
    """visualize all images based on dataset and annotations info"""
    merge_pred(ann_path, mode, name)
    ann_path = os.path.join(ann_path, name + '_' + mode + '.json')
    visual_allimages(ann_path, image_path, save_path, ratio)


def visual_allimages(anno_file, image_path, save_path, ratio=1):
    """visualize all images and annotations info"""
    coco = COCO.COCO(anno_file)
    image_ids = coco.getImgIds()
    images = []
    anns = {}
    for img_id in image_ids:
        idxs = coco.getAnnIds(imgIds=[img_id])
        if idxs:
            images.append(img_id)
            anns[img_id] = idxs

    for img_id in images:
        file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
        img_path = os.path.join(image_path, file_name)
        annos = coco.loadAnns(anns[img_id])
        img = cv2.imread(img_path)
        return visual_image(img, annos, save_path, ratio)


def visual_image(img, annos, save_path, ratio=None, height=None, width=None, name=None, score_threshold=0.01):
    """visualize image and annotations info"""
    # annos: list type, in which every element is a dict
    h, w = img.shape[0], img.shape[1]
    if height is not None and width is not None and (height != h or width != w):
        img, annos = resize_image(img, annos, width, height)
    elif ratio not in (None, 1):
        img, annos = resize_image(img, annos, w * ratio, h * ratio)

    h, w = img.shape[0], img.shape[1]
    num_objects = len(annos)
    num = 0
    for i in range(num_objects):
        ann = annos[i]
        bbox = coco_box_to_bbox(ann['bbox'])
        if "score" in ann:
            score = ann["score"]
            if score < score_threshold and num != 0:
                continue
            num += 1
            txt = ("p" + "{:.2f}".format(ann["score"]))
            cv2.putText(img, txt, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

        ct = (int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2))
        cv2.circle(img, ct, 2, (0, 255, 0), thickness=-1, lineType=cv2.FILLED)
        bbox = np.array(bbox, dtype=np.int32).tolist()
        cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)

        keypoints = ann["keypoints"]
        keypoints = np.array(keypoints, dtype=np.int32).reshape(_NUM_JOINTS, 3).tolist()

        left_part = [0, 1, 3, 5, 7, 9, 11, 13, 15]
        right_part = [0, 2, 4, 6, 8, 10, 12, 14, 16]
        for pair in data_cfg.edges:
            partA = pair[0]
            partB = pair[1]
            if partA in left_part and partB in left_part:
                color = (255, 0, 0)
            elif partA in right_part and partB in right_part:
                color = (0, 0, 255)
            else:
                color = (139, 0, 255)
            p_a = tuple(keypoints[partA][:2])
            p_b = tuple(keypoints[partB][:2])
            mask_a = keypoints[partA][2]
            mask_b = keypoints[partB][2]
            if (p_a[0] >= 0 and p_a[0] < w and p_a[1] >= 0 and p_a[1] < h and
                    p_b[0] >= 0 and p_b[0] < w and p_b[1] >= 0 and p_b[1] < h and
                    mask_a * mask_b > 0):
                cv2.line(img, p_a, p_b, color, 2)
                cv2.circle(img, p_a, 3, color, thickness=-1, lineType=cv2.FILLED)
                cv2.circle(img, p_b, 3, color, thickness=-1, lineType=cv2.FILLED)
    if annos and "image_id" in annos[0]:
        img_id = annos[0]["image_id"]
    else:
        img_id = random.randint(0, 9999999)
    if name is None:
        image_name = "cv_image_" + str(img_id) + ".png"
    else:
        image_name = "cv_image_" + str(img_id) + name + ".png"
    cv2.imwrite("{}/{}".format(save_path, image_name), img)
@@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train CenterNet and get network model files(.ckpt)
"""

import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config

_current_dir = os.path.dirname(os.path.realpath(__file__))

parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
                    help="Run distribute, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
                    help="Profiling to parsing runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default=1, help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
                    "i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
                    help="Enable save checkpoint, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
                    help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
                    help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default=1, help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="",
                    help="Mindrecord files directory. If it is empty, mindrecord format files will be generated "
                         "based on the original dataset and annotation information. If mindrecord_dir isn't empty, "
                         "mindrecord_dir will be used in place of data_dir and anno_path.")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
                    "the absolute image path is joined by the data_dir "
                    "and the relative path in anno_path")
parser.add_argument("--visual_image", type=str, default="false", help="Visualize the ground truth and predicted image")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")

args_opt = parser.parse_args()


def _set_parallel_all_reduce_split():
    """set centernet all_reduce fusion split"""
    if net_config.last_level == 5:
        context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
    elif net_config.last_level == 6:
        context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
    else:
        raise ValueError("The total num of allreduced grads for last level = {} is unknown, "
                         "please re-split after the true value is known".format(net_config.last_level))


def _get_params_groups(network, optimizer):
    """
    Get param groups
    """
    params = network.trainable_params()
    decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
    other_params = list(filter(optimizer.decay_filter, params))
    group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]
    return group_params


def _get_optimizer(network, dataset_size):
    """get optimizer, only support Adam right now."""
    if train_config.optimizer == 'Adam':
        group_params = _get_params_groups(network, train_config.Adam)
        if train_config.lr_schedule == "PolyDecay":
            lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
                                                     end_learning_rate=train_config.PolyDecay.end_learning_rate,
                                                     warmup_steps=train_config.PolyDecay.warmup_steps,
                                                     decay_steps=args_opt.train_steps,
                                                     power=train_config.PolyDecay.power)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
        elif train_config.lr_schedule == "MultiDecay":
            multi_epochs = train_config.MultiDecay.multi_epochs
            if not isinstance(multi_epochs, (list, tuple)):
                raise TypeError("multi_epochs must be list or tuple.")
            if not multi_epochs:
                multi_epochs = [args_opt.epoch_size]
            lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
                                                      warmup_steps=train_config.MultiDecay.warmup_steps,
                                                      multi_epochs=multi_epochs,
                                                      steps_per_epoch=dataset_size,
                                                      factor=train_config.MultiDecay.factor)
            optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
        else:
            raise ValueError("Don't support lr_schedule {}, only support [PolyDecay, MultiDecay]".
                             format(train_config.lr_schedule))
    else:
        raise ValueError("Don't support optimizer {}, only support Adam".
                         format(train_config.optimizer))
    return optimizer


def train():
    """training CenterNet"""
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    context.set_context(enable_auto_mixed_precision=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(save_graphs=False)

    ckpt_save_dir = args_opt.save_checkpoint_path
    if args_opt.distribute == "true":
        D.init()
        device_num = args_opt.device_num
        rank = args_opt.device_id % device_num
        ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                          device_num=device_num)
        _set_parallel_all_reduce_split()
    else:
        rank = 0
        device_num = 1
    num_workers = device_num * 8
    # Start create dataset!
    # mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
    logger.info("Begin creating dataset for CenterNet")
    prefix = "coco_hp.train.mind"
    coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode="train")
    coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
    dataset = coco.create_train_dataset(args_opt.mindrecord_dir, prefix, batch_size=train_config.batch_size,
                                        device_num=device_num, rank=rank, num_parallel_workers=num_workers,
                                        do_shuffle=args_opt.do_shuffle == 'true')
    dataset_size = dataset.get_dataset_size()
    logger.info("Create dataset done!")

    net_with_loss = CenterNetMultiPoseLossCell(net_config)

    new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
    if args_opt.train_steps > 0:
        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
    else:
        args_opt.train_steps = args_opt.epoch_size * dataset_size
    logger.info("train steps: {}".format(args_opt.train_steps))

    optimizer = _get_optimizer(net_with_loss, dataset_size)

    callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                     keep_checkpoint_max=args_opt.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
                                     directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
        callback.append(ckpoint_cb)

    if args_opt.load_checkpoint_path:
        param_dict = load_checkpoint(args_opt.load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
                                                sens=train_config.loss_scale_value)

    model = Model(net_with_grads)

    model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
                sink_size=args_opt.data_sink_steps)


if __name__ == '__main__':
    if args_opt.need_profiler == "true":
        profiler = Profiler(output_path=args_opt.profiler_path)
    set_seed(0)
    train()
    if args_opt.need_profiler == "true":
        profiler.analyse()