add centernet scripts

pull/10038/head
shibeiji 5 years ago
parent 67c3fded73
commit 6afa96de4d

@@ -70,6 +70,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [FaceQualityAssessment](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceQualityAssessment/README.md)
- [FaceRecognitionForTracking](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognitionForTracking/README.md)
- [FaceRecognition](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/FaceRecognition/README.md)
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)

File diff suppressed because it is too large

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CenterNet evaluation script.
"""
import os
import time
import copy
import json
import argparse
import cv2
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore.log as logger
from src import COCOHP, CenterNetMultiPoseEval
from src import convert_eval_format, post_process, merge_outputs
from src import visual_image
from src.config import dataset_config, net_config, eval_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet evaluation')
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is obtained by joining data_dir "
"and the relative path in anno_path")
parser.add_argument("--run_mode", type=str, default="test", help="test or validation, default is test.")
parser.add_argument("--visual_image", type=str, default="false", help="Visualize the ground truth and predicted images")
parser.add_argument("--enable_eval", type=str, default="true", help="Whether to evaluate accuracy after prediction")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def predict():
'''
Predict function
'''
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
logger.info("Begin creating {} dataset".format(args_opt.run_mode))
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode=args_opt.run_mode)
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir,
keep_res=eval_config.keep_res, flip_test=eval_config.flip_test)
dataset = coco.create_eval_dataset()
net_for_eval = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
net_for_eval.set_train(False)
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_for_eval, param_dict)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
if args_opt.visual_image == "true":
save_pred_image_path = os.path.join(save_path, "pred_image")
if not os.path.exists(save_pred_image_path):
os.makedirs(save_pred_image_path)
save_gt_image_path = os.path.join(save_path, "gt_image")
if not os.path.exists(save_gt_image_path):
os.makedirs(save_gt_image_path)
total_nums = dataset.get_dataset_size()
print("\n========================================\n")
print("Total images num: ", total_nums)
print("Processing, please wait a moment.")
pred_annos = {"images": [], "annotations": []}
index = 0
for data in dataset.create_dict_iterator(num_epochs=1):
index += 1
image = data['image']
image_id = data['image_id'].asnumpy().reshape((-1))[0]
# run prediction
start = time.time()
detections = []
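# run inference once per test scale; the results from all scales are merged by merge_outputs below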
for scale in eval_config.multi_scales:
images, meta = coco.pre_process_for_test(image.asnumpy(), image_id, scale)
detection = net_for_eval(Tensor(images))
dets = post_process(detection.asnumpy(), meta, scale)
detections.append(dets)
end = time.time()
print("Image {}/{} id: {} cost time {} ms".format(index, total_nums, image_id, (end - start) * 1000.))
# post-process
soft_nms = eval_config.soft_nms or len(eval_config.multi_scales) > 0
detections = merge_outputs(detections, soft_nms)
# get prediction result
pred_json = convert_eval_format(detections, image_id)
gt_image_info = coco.coco.loadImgs([image_id])
for image_info in pred_json["images"]:
pred_annos["images"].append(image_info)
for image_anno in pred_json["annotations"]:
pred_annos["annotations"].append(image_anno)
if args_opt.visual_image == "true":
img_file = os.path.join(coco.image_path, gt_image_info[0]['file_name'])
gt_image = cv2.imread(img_file)
if args_opt.run_mode != "test":
annos = coco.coco.loadAnns(coco.anns[image_id])
visual_image(copy.deepcopy(gt_image), annos, save_gt_image_path)
anno = copy.deepcopy(pred_json["annotations"])
visual_image(gt_image, anno, save_pred_image_path, score_threshold=eval_config.score_thresh)
# save results
save_path = os.path.join(args_opt.save_result_dir, args_opt.run_mode)
if not os.path.exists(save_path):
os.makedirs(save_path)
pred_anno_file = os.path.join(save_path, '{}_pred_result.json'.format(args_opt.run_mode))
json.dump(pred_annos, open(pred_anno_file, 'w'))
pred_res_file = os.path.join(save_path, '{}_pred_eval.json'.format(args_opt.run_mode))
json.dump(pred_annos["annotations"], open(pred_res_file, 'w'))
if args_opt.run_mode != "test" and args_opt.enable_eval == "true":
run_eval(coco.annot_path, pred_res_file)
def run_eval(gt_anno, pred_anno):
"""evaluation by coco api"""
coco = COCO(gt_anno)
coco_dets = coco.loadRes(pred_anno)
coco_eval = COCOeval(coco, coco_dets, "keypoints")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
coco_eval = COCOeval(coco, coco_dets, "bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
if __name__ == "__main__":
predict()

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Export CenterNet mindir model.
"""
import argparse
import numpy as np
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src import CenterNetMultiPoseEval
from src.config import net_config, eval_config, export_config
parser = argparse.ArgumentParser(description='centernet export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == '__main__':
net = CenterNetMultiPoseEval(net_config, eval_config.flip_test, eval_config.K)
net.set_train(False)
param_dict = load_checkpoint(export_config.ckpt_file)
load_param_into_net(net, param_dict)
input_shape = [1, 3, export_config.input_res[0], export_config.input_res[1]]
input_data = Tensor(np.random.uniform(-1.0, 1.0, size=input_shape).astype(np.float32))
export(net, input_data, file_name=export_config.export_name, file_format=export_config.export_format)

@@ -0,0 +1,51 @@
# Run distribute train
## description
The number of Ascend accelerators is allocated automatically based on the device_num set in the hccl config file; you do not need to specify it.
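The launcher only relies on a few fields of that file. Below is a minimal sketch of a single-server, two-device rank table; the device ids, rank ids and IP addresses are placeholders, and a real file generated by `hccl_tools.py` contains additional metadata fields.

```json
{
  "server_list": [
    {
      "device": [
        {"device_id": "5", "device_ip": "192.168.100.105", "rank_id": "0"},
        {"device_id": "6", "device_ip": "192.168.100.106", "rank_id": "1"}
      ]
    }
  ]
}
```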
## how to use
For example, to generate the launch command for distributed training of the CenterNet model on Ascend accelerators, run the following command in the `/centernet/` directory:
```python
python ./scripts/ascend_distributed_launcher/get_distribute_train_cmd.py --run_script_dir ./train.py --hyper_parameter_config_dir ./scripts/ascend_distributed_launcher/hyper_parameter_config.ini --data_dir /path/dataset/ --mindrecord_dir /path/mindrecord_dataset/ --hccl_config_dir model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
```
output:
```text
hccl_config_dir: model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
the number of logical core: 192
avg_core_per_rank: 96
rank_size: 2
start training for rank 0, device 5:
rank_id: 0
device_id: 5
core nums: 0-95
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG5/training_log.txt
start training for rank 1, device 6:
rank_id: 1
device_id: 6
core nums: 96-191
epoch_size: 350
data_dir: /path/dataset/
mindrecord_dir: /path/mindrecord_dataset/
log file dir: ./LOG6/training_log.txt
```
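Besides printing this summary, the launcher does not start training itself: the per-rank commands are written to the file given by `--cmd_file` (default `distributed_cmd.sh`). `run_distributed_train_ascend.sh` runs that file automatically; when the launcher is invoked directly, you run it yourself. The following is only a rough sketch of the block generated for one rank, with illustrative paths, device ids and hyper parameters:

```bash
export HCCL_CONNECT_TIMEOUT=120   # 1200 when launched via run_distributed_train_ascend.sh
export RANK_TABLE_FILE=model_zoo/utils/hccl_tools/hccl_2p_56_x.x.x.x.json
export RANK_SIZE=2
export DEVICE_ID=5
export RANK_ID=0
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
rm -rf LOG5
mkdir ./LOG5
cp *.py ./LOG5
mkdir -p ./LOG5/ms_log
env > ./LOG5/env.log
export GLOG_log_dir=/path/centernet/LOG5/ms_log
export GLOG_logtostderr=0
cd /path/centernet/LOG5
# remaining options come from hyper_parameter_config.ini
taskset -c 0-95 nohup python /path/centernet/train.py --distribute=true --epoch_size=350 \
    --data_dir=/path/dataset/ --mindrecord_dir=/path/mindrecord_dataset/ \
    --device_id=5 --device_num=2 >./training_log.txt 2>&1 &
cd -
```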
## Note
1. Note that `hccl_2p_56_x.x.x.x.json` can be generated with [hccl_tools.py](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
2. For hyper parameters, customize the script `hyper_parameter_config.ini`. Note that the following hyper parameters must not be configured there:
    - device_id
    - device_num
    - data_dir
3. For other models, customize the option `run_script` and the corresponding `hyper_parameter_config.ini`.

@@ -0,0 +1,170 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""distributed training launch script"""
import os
import json
import configparser
import multiprocessing
from argparse import ArgumentParser
def parse_args():
"""
Parse command line arguments.
Returns:
args, the parsed arguments.
Examples:
>>> parse_args()
"""
parser = ArgumentParser(description="mindspore distributed training")
parser.add_argument("--run_script_dir", type=str, default="",
help="Run script path, it is better to use absolute path")
parser.add_argument("--hyper_parameter_config_dir", type=str, default="",
help="Hyper Parameter config path, it is better to use absolute path")
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
help="Mindrecord directory. If the mindrecord_dir is empty, it will generate mindrecord files from "
"data_dir and anno_path. Note if mindrecord_dir isn't empty, it will use mindrecord_dir "
"rather than data_dir and anno_path. Default is ./Mindrecord_train")
parser.add_argument("--data_dir", type=str, default="",
help="Data path, it is better to use absolute path")
parser.add_argument("--hccl_config_dir", type=str, default="",
help="Hccl config path, it is better to use absolute path")
parser.add_argument("--cmd_file", type=str, default="distributed_cmd.sh",
help="Path of the generated cmd file.")
parser.add_argument("--hccl_time_out", type=int, default=120,
help="Seconds before the HCCL connection times out, "
"default: 120, the same as the HCCL default config")
args = parser.parse_args()
return args
def append_cmd(cmd, s):
cmd += s
cmd += "\n"
return cmd
def append_cmd_env(cmd, key, value):
return append_cmd(cmd, "export " + str(key) + "=" + str(value))
def distribute_train():
"""
Generate the distributed training commands. The number of Ascend accelerators is allocated
automatically based on the device_num set in the hccl config file; you do not need to specify it.
"""
cmd = ""
print("start", __file__)
args = parse_args()
run_script = args.run_script_dir
data_dir = args.data_dir
mindrecord_dir = args.mindrecord_dir
cf = configparser.ConfigParser()
cf.read(args.hyper_parameter_config_dir)
cfg = dict(cf.items("config"))
print("hccl_config_dir:", args.hccl_config_dir)
print("hccl_time_out:", args.hccl_time_out)
cmd = append_cmd_env(cmd, 'HCCL_CONNECT_TIMEOUT', args.hccl_time_out)
cmd = append_cmd_env(cmd, 'RANK_TABLE_FILE', args.hccl_config_dir)
cores = multiprocessing.cpu_count()
print("the number of logical core:", cores)
# get device_ips
device_ips = {}
with open('/etc/hccn.conf', 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip.strip()
with open(args.hccl_config_dir, "r", encoding="utf-8") as fin:
hccl_config = json.loads(fin.read())
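# rank_size counts devices across every server in the rank table; the local server is the one
# whose first device_ip appears among the addresses read from /etc/hccn.conf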
rank_size = 0
for server in hccl_config["server_list"]:
rank_size += len(server["device"])
if server["device"][0]["device_ip"] in device_ips.values():
this_server = server
cmd = append_cmd_env(cmd, "RANK_SIZE", str(rank_size))
print("total rank size:", rank_size)
print("this server rank size:", len(this_server["device"]))
avg_core_per_rank = int(int(cores) / len(this_server["device"]))
core_gap = avg_core_per_rank - 1
print("avg_core_per_rank:", avg_core_per_rank)
count = 0
for instance in this_server["device"]:
device_id = instance["device_id"]
rank_id = instance["rank_id"]
print("\nstart training for rank " + str(rank_id) + ", device " + str(device_id) + ":")
print("rank_id:", rank_id)
print("device_id:", device_id)
start = count * int(avg_core_per_rank)
count += 1
end = start + core_gap
cmdopt = str(start) + "-" + str(end)
cmd = append_cmd_env(cmd, "DEVICE_ID", str(device_id))
cmd = append_cmd_env(cmd, "RANK_ID", str(rank_id))
cmd = append_cmd_env(cmd, "DEPLOY_MODE", '0')
cmd = append_cmd_env(cmd, "GE_USE_STATIC_MEMORY", '1')
cmd = append_cmd(cmd, "rm -rf LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir ./LOG" + str(device_id))
cmd = append_cmd(cmd, "cp *.py ./LOG" + str(device_id))
cmd = append_cmd(cmd, "mkdir -p ./LOG" + str(device_id) + "/ms_log")
cmd = append_cmd(cmd, "env > ./LOG" + str(device_id) + "/env.log")
cur_dir = os.getcwd()
cmd = append_cmd_env(cmd, "GLOG_log_dir", cur_dir + "/LOG" + str(device_id) + "/ms_log")
cmd = append_cmd_env(cmd, "GLOG_logtostderr", "0")
print("core_nums:", cmdopt)
print("epoch_size:", str(cfg['epoch_size']))
print("data_dir:", data_dir)
print("mindrecord_dir:", mindrecord_dir)
print("log_file_dir: " + cur_dir + "/LOG" + str(device_id) + "/training_log.txt")
cmd = append_cmd(cmd, "cd " + cur_dir + "/LOG" + str(device_id))
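# bind this rank's training process to its own CPU core range and launch it in the background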
run_cmd = 'taskset -c ' + cmdopt + ' nohup python ' + run_script + " "
opt = " ".join(["--" + key + "=" + str(cfg[key]) for key in cfg.keys()])
if ('device_id' in opt) or ('device_num' in opt) or ('data_dir' in opt):
raise ValueError("hyper_parameter_config.ini must not set 'device_id',"
" 'device_num' or 'data_dir'!")
run_cmd += opt
run_cmd += " --data_dir=" + data_dir
run_cmd += " --mindrecord_dir=" + mindrecord_dir
run_cmd += ' --device_id=' + str(device_id) + ' --device_num=' \
+ str(rank_size) + ' >./training_log.txt 2>&1 &'
cmd = append_cmd(cmd, run_cmd)
cmd = append_cmd(cmd, "cd -")
cmd += "\n"
with open(args.cmd_file, "w") as f:
f.write(cmd)
if __name__ == "__main__":
distribute_train()

@@ -0,0 +1,14 @@
[config]
distribute=true
epoch_size=350
enable_save_ckpt=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=50
load_checkpoint_path=""
save_checkpoint_path=./
save_checkpoint_steps=3000
save_checkpoint_num=1
need_profiler=false
profiler_path=./profiler
visual_image=false

@@ -0,0 +1,36 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_train_ascend.sh DATA_DIR MINDRECORD_DIR RANK_TABLE_FILE"
echo "for example: bash run_distributed_train_ascend.sh /path/dataset /path/mindrecord /path/hccl.json"
echo "It is better to use absolute path."
echo "For hyper parameters, please note that you should customize the script:
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "=============================================================================================================="
CUR_DIR=`pwd`
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_train_cmd.py \
--run_script_dir=${CUR_DIR}/train.py \
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \
--data_dir=$1 \
--mindrecord_dir=$2 \
--hccl_config_dir=$3 \
--hccl_time_out=1200 \
--cmd_file=distributed_cmd.sh
bash distributed_cmd.sh

@@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_eval_ascend.sh DEVICE_ID"
echo "for example: bash run_standalone_eval_ascend.sh 0"
echo "=============================================================================================================="
DEVICE_ID=$1
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
# install nms module from third party
if python -c "import nms" > /dev/null 2>&1
then
echo "NMS module already exists, no need to reinstall."
else
echo "NMS module was not found, installing it now..."
git clone https://github.com/xingyizhou/CenterNet.git
cd CenterNet/src/lib/external/
make
python setup.py install
cd -
rm -rf CenterNet
fi
python ${PROJECT_DIR}/../eval.py \
--device_id=$DEVICE_ID \
--load_checkpoint_path="" \
--data_dir="" \
--visual_image=true \
--enable_eval=true \
--save_result_dir="" \
--run_mode=val > log.txt 2>&1 &

@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_pretrain_ascend.sh DEVICE_ID EPOCH_SIZE"
echo "for example: bash run_standalone_pretrain_ascend.sh 0 350"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../train.py \
--distribute=false \
--need_profiler=false \
--profiler_path=./profiler \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_save_ckpt=true \
--do_shuffle=true \
--enable_data_sink=true \
--data_sink_steps=50 \
--load_checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir="" \
--mindrecord_dir="" \
--visual_image=false \
--save_result_dir=""> log.txt 2>&1 &

@@ -0,0 +1,28 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CenterNet Init."""
from .centernet_pose import GatherMultiPoseFeatureCell, CenterNetMultiPoseLossCell, \
CenterNetWithLossScaleCell, CenterNetMultiPoseEval
from .dataset import COCOHP
from .visual import visual_allimages, visual_image
from .decode import MultiPoseDecode
from .post_process import convert_eval_format, to_float, resize_detection, post_process, merge_outputs
__all__ = [
"GatherMultiPoseFeatureCell", "CenterNetMultiPoseLossCell", "CenterNetWithLossScaleCell", \
"CenterNetMultiPoseEval", "COCOHP", "visual_allimages", "visual_image", "MultiPoseDecode", \
"convert_eval_format", "to_float", "resize_detection", "post_process", "merge_outputs"
]

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,122 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config settings, used in dataset.py, train.py and eval.py
"""
import numpy as np
from easydict import EasyDict as edict
dataset_config = edict({
'num_classes': 1,
'num_joints': 17,
'max_objs': 32,
'input_res': [512, 512],
'output_res': [128, 128],
'rand_crop': False,
'shift': 0.1,
'scale': 0.4,
'aug_rot': 0.0,
'rotate': 0,
'flip_prop': 0.5,
'color_aug': False,
'mean': np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32),
'std': np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32),
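# pairs of left/right keypoint indices that are swapped when an image is horizontally flipped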
'flip_idx': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]],
'edges': [[0, 1], [0, 2], [1, 3], [2, 4], [4, 6], [3, 5], [5, 6],
[5, 7], [7, 9], [6, 8], [8, 10], [6, 12], [5, 11], [11, 12],
[12, 14], [14, 16], [11, 13], [13, 15]],
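# PCA eigenvalues/eigenvectors of image pixel values, used for color (lighting) augmentation when color_aug is enabled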
'eig_val': np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32),
'eig_vec': np.array([[-0.58752847, -0.69563484, 0.41340352],
[-0.5832747, 0.00994535, -0.81221408],
[-0.56089297, 0.71832671, 0.41158938]], dtype=np.float32),
'categories': [{"supercategory": "person",
"id": 1,
"name": "person",
"keypoints": ["nose", "left_eye", "right_eye", "left_ear", "right_ear",
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
"left_wrist", "right_wrist", "left_hip", "right_hip",
"left_knee", "right_knee", "left_ankle", "right_ankle"],
"skeleton": [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13],
[6, 7], [6, 8], [7, 9], [8, 10], [9, 11], [2, 3], [1, 2], [1, 3],
[2, 4], [3, 5], [4, 6], [5, 7]]}],
})
net_config = edict({
'down_ratio': 4,
'last_level': 6,
'final_kernel': 1,
'stage_levels': [1, 1, 1, 2, 2, 1],
'stage_channels': [16, 32, 64, 128, 256, 512],
'head_conv': 256,
'dense_hp': True,
'hm_hp': True,
'reg_hp_offset': True,
'reg_offset': True,
'hm_weight': 1,
'off_weight': 1,
'wh_weight': 0.1,
'hp_weight': 1,
'hm_hp_weight': 1,
'mse_loss': False,
'reg_loss': 'l1',
})
train_config = edict({
'batch_size': 32,
'loss_scale_value': 1024,
'optimizer': 'Adam',
'lr_schedule': 'MultiDecay',
'Adam': edict({
'weight_decay': 0.0,
'decay_filter': lambda x: x.name.endswith('.bias') or x.name.endswith('.beta') or x.name.endswith('.gamma'),
}),
'PolyDecay': edict({
'learning_rate': 1.2e-4,
'end_learning_rate': 5e-7,
'power': 5.0,
'eps': 1e-7,
'warmup_steps': 2000,
}),
'MultiDecay': edict({
'learning_rate': 1.2e-4,
'eps': 1e-7,
'warmup_steps': 2000,
'multi_epochs': [270, 300],
'factor': 10,
})
})
eval_config = edict({
'flip_test': False,
'soft_nms': False,
'keep_res': True,
'multi_scales': [1.0],
'pad': 31,
'K': 100,
'score_thresh': 0.3
})
export_config = edict({
'input_res': dataset_config.input_res,
'ckpt_file': "./ckpt_file.ckpt",
'export_format': "MINDIR",
'export_name': "CenterNet_MultiPose",
})

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Post-process functions after decoding
"""
import numpy as np
from src.config import dataset_config as config
from .image import get_affine_transform, affine_transform, transform_preds
from .visual import coco_box_to_bbox
try:
from nms import soft_nms_39
except ImportError:
print('NMS not installed! Do \n cd $CenterNet_ROOT/scripts/ \n'
'and see run_standalone_eval.sh for more details to install it\n')
_NUM_JOINTS = config.num_joints
def post_process(dets, meta, scale=1):
"""rescale detection to original scale"""
c, s, h, w = meta['c'], meta['s'], meta['out_height'], meta['out_width']
b, K, N = dets.shape
assert b == 1, "only single image was post-processed"
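# each detection row holds 39 values: bbox (x1, y1, x2, y2), score, and 17 keypoints as (x, y) pairs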
dets = dets.reshape((K, N))
bbox = transform_preds(dets[:, :4].reshape(-1, 2), c, s, (w, h)) / scale
pts = transform_preds(dets[:, 5:39].reshape(-1, 2), c, s, (w, h)) / scale
top_preds = np.concatenate(
[bbox.reshape(-1, 4), dets[:, 4:5],
pts.reshape(-1, 34)], axis=1).astype(np.float32).reshape(-1, 39)
return top_preds
def merge_outputs(detections, soft_nms=True):
"""merge detections together by nms"""
results = np.concatenate([detection for detection in detections], axis=0).astype(np.float32)
if soft_nms:
soft_nms_39(results, Nt=0.5, threshold=0.01, method=2)
results = results.tolist()
return results
def convert_eval_format(detections, img_id):
"""convert detection to annotation json format"""
# detections. scores: (b, K); bboxes: (b, K, 4); kps: (b, K, J * 2); clses: (b, K)
# only batch_size = 1 is supported
detections = np.array(detections).reshape((-1, 39))
pred_anno = {"images": [], "annotations": []}
num_objs, _ = detections.shape
for i in range(num_objs):
score = detections[i][4]
bbox = detections[i][0:4]
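# convert (x1, y1, x2, y2) to COCO-style (x, y, w, h)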
bbox[2:4] = bbox[2:4] - bbox[0:2]
bbox = list(map(to_float, bbox))
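# append visibility flag 1 to every predicted (x, y) so keypoints follow the COCO (x, y, v) triple format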
keypoints = np.concatenate([
np.array(detections[i][5:39], dtype=np.float32).reshape(-1, 2),
np.ones((17, 1), dtype=np.float32)], axis=1).reshape(_NUM_JOINTS * 3).tolist()
keypoints = list(map(to_float, keypoints))
class_id = 1
pred = {
"image_id": int(img_id),
"category_id": int(class_id),
"bbox": bbox,
"score": to_float(score),
"keypoints": keypoints
}
pred_anno["annotations"].append(pred)
if pred_anno["annotations"]:
pred_anno["images"].append({"id": int(img_id)})
return pred_anno
def to_float(x):
"""format float data"""
return float("{:.2f}".format(x))
def resize_detection(detection, pred, gt):
"""resize object annotation info"""
height, width = gt[0], gt[1]
c = np.array([pred[1] / 2., pred[0] / 2.], dtype=np.float32)
s = max(pred[0], pred[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
anns = detection["annotations"]
num_objects = len(anns)
resized_detection = {"images": detection["images"], "annotations": []}
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
for j in range(_NUM_JOINTS):
pts[j, :2] = affine_transform(pts[j, :2], trans_output)
bbox = [bbox[0], bbox[1], w, h]
keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
ann["bbox"] = list(map(to_float, bbox))
ann["keypoints"] = list(map(to_float, keypoints))
resized_detection["annotations"].append(ann)
return resized_detection

File diff suppressed because it is too large

@@ -0,0 +1,175 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Visualization utilities for images and annotations.
"""
import os
import json
import random
import cv2
import numpy as np
import pycocotools.coco as COCO
from .config import dataset_config as data_cfg
from .image import get_affine_transform, affine_transform
_NUM_JOINTS = data_cfg.num_joints
def coco_box_to_bbox(box):
"""convert height/width to position coordinates"""
bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32)
return bbox
def resize_image(image, anns, width, height):
"""resize image to specified scale"""
h, w = image.shape[0], image.shape[1]
c = np.array([image.shape[1] / 2., image.shape[0] / 2.], dtype=np.float32)
s = max(image.shape[0], image.shape[1]) * 1.0
trans_output = get_affine_transform(c, s, 0, [width, height])
out_img = cv2.warpAffine(image, trans_output, (width, height), flags=cv2.INTER_LINEAR)
num_objects = len(anns)
resize_anno = []
for i in range(num_objects):
ann = anns[i]
bbox = coco_box_to_bbox(ann['bbox'])
pts = np.array(ann['keypoints'], np.float32).reshape(_NUM_JOINTS, 3)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[0::2] = np.clip(bbox[0::2], 0, width - 1)
bbox[1::2] = np.clip(bbox[1::2], 0, height - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if (h > 0 and w > 0):
ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
for j in range(_NUM_JOINTS):
pts[j, :2] = affine_transform(pts[j, :2], trans_output)
bbox = [ct[0] - w / 2, ct[1] - h / 2, w, h, 1]
keypoints = pts.reshape(_NUM_JOINTS * 3).tolist()
ann["bbox"] = bbox
ann["keypoints"] = keypoints
gt = ann
resize_anno.append(gt)
return out_img, resize_anno
def merge_pred(ann_path, mode="val", name="merged_annotations"):
"""merge annotation info of each image together"""
files = os.listdir(ann_path)
data_files = []
for file_name in files:
if "json" in file_name:
data_files.append(os.path.join(ann_path, file_name))
pred = {"images": [], "annotations": []}
for file in data_files:
anno = json.load(open(file, 'r'))
if "images" in anno:
for img in anno["images"]:
pred["images"].append(img)
if "annotations" in anno:
for ann in anno["annotations"]:
pred["annotations"].append(ann)
json.dump(pred, open('{}/{}_{}.json'.format(ann_path, name, mode), 'w'))
def visual(ann_path, image_path, save_path, ratio=1, mode="val", name="merged_annotations"):
"""visualize all images based on dataset and annotation info"""
merge_pred(ann_path, mode, name)
ann_path = os.path.join(ann_path, name + '_' + mode + '.json')
visual_allimages(ann_path, image_path, save_path, ratio)
def visual_allimages(anno_file, image_path, save_path, ratio=1):
"""visualize all images and annotations info"""
coco = COCO.COCO(anno_file)
image_ids = coco.getImgIds()
images = []
anns = {}
for img_id in image_ids:
idxs = coco.getAnnIds(imgIds=[img_id])
if idxs:
images.append(img_id)
anns[img_id] = idxs
for img_id in images:
file_name = coco.loadImgs(ids=[img_id])[0]['file_name']
img_path = os.path.join(image_path, file_name)
annos = coco.loadAnns(anns[img_id])
img = cv2.imread(img_path)
return visual_image(img, annos, save_path, ratio)
def visual_image(img, annos, save_path, ratio=None, height=None, width=None, name=None, score_threshold=0.01):
"""visualize image and annotations info"""
# annos: list type, in which all the element is dict
h, w = img.shape[0], img.shape[1]
if height is not None and width is not None and (height != h or width != w):
img, annos = resize_image(img, annos, width, height)
elif ratio not in (None, 1):
img, annos = resize_image(img, annos, w * ratio, h * ratio)
h, w = img.shape[0], img.shape[1]
num_objects = len(annos)
num = 0
for i in range(num_objects):
ann = annos[i]
bbox = coco_box_to_bbox(ann['bbox'])
if "score" in ann:
score = ann["score"]
if score < score_threshold and num != 0:
continue
num += 1
txt = ("p" + "{:.2f}".format(ann["score"]))
cv2.putText(img, txt, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
ct = (int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2))
cv2.circle(img, ct, 2, (0, 255, 0), thickness=-1, lineType=cv2.FILLED)
bbox = np.array(bbox, dtype=np.int32).tolist()
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
keypoints = ann["keypoints"]
keypoints = np.array(keypoints, dtype=np.int32).reshape(_NUM_JOINTS, 3).tolist()
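# COCO keypoint indices on the left/right side of the body; index 0 (nose) is in both, used to pick a limb color per edge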
left_part = [0, 1, 3, 5, 7, 9, 11, 13, 15]
right_part = [0, 2, 4, 6, 8, 10, 12, 14, 16]
for pair in data_cfg.edges:
partA = pair[0]
partB = pair[1]
if partA in left_part and partB in left_part:
color = (255, 0, 0)
elif partA in right_part and partB in right_part:
color = (0, 0, 255)
else:
color = (139, 0, 255)
p_a = tuple(keypoints[partA][:2])
p_b = tuple(keypoints[partB][:2])
mask_a = keypoints[partA][2]
mask_b = keypoints[partB][2]
if (p_a[0] >= 0 and p_a[0] < w and p_a[1] >= 0 and p_a[1] < h and
p_b[0] >= 0 and p_b[0] < w and p_b[1] >= 0 and p_b[1] < h and
mask_a * mask_b > 0):
cv2.line(img, p_a, p_b, color, 2)
cv2.circle(img, p_a, 3, color, thickness=-1, lineType=cv2.FILLED)
cv2.circle(img, p_b, 3, color, thickness=-1, lineType=cv2.FILLED)
if annos and "image_id" in annos[0]:
img_id = annos[0]["image_id"]
else:
img_id = random.randint(0, 9999999)
if name is None:
image_name = "cv_image_" + str(img_id) + ".png"
else:
image_name = "cv_image_" + str(img_id) + name + ".png"
cv2.imwrite("{}/{}".format(save_path, image_name), img)

@@ -0,0 +1,202 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Train CenterNet and get network model files (.ckpt)
"""
import os
import argparse
import mindspore.communication.management as D
from mindspore.communication.management import get_rank
from mindspore import context
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Adam
from mindspore import log as logger
from mindspore.common import set_seed
from mindspore.profiler import Profiler
from src.dataset import COCOHP
from src import CenterNetMultiPoseLossCell, CenterNetWithLossScaleCell
from src.utils import LossCallBack, CenterNetPolynomialDecayLR, CenterNetMultiEpochsDecayLR
from src.config import dataset_config, net_config, train_config
_current_dir = os.path.dirname(os.path.realpath(__file__))
parser = argparse.ArgumentParser(description='CenterNet training')
parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
help="Run distribute, default is false.")
parser.add_argument("--need_profiler", type=str, default="false", choices=["true", "false"],
help="Profiling to parsing runtime info, default is false.")
parser.add_argument("--profiler_path", type=str, default=" ", help="The path to save profiling data")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1,"
"i.e. run all steps according to epoch number.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
help="Enable save checkpoint, default is true.")
parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
help="Enable shuffle for dataset, default is true.")
parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
help="Enable data sink, default is true.")
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, default is 1000.")
parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
parser.add_argument("--mindrecord_dir", type=str, default="",
help="Mindrecord files directory. If it is empty, mindrecord format files will be generated "
"based on the original dataset and annotation information. If mindrecord_dir isn't empty, "
"mindrecord_dir will be used in place of data_dir and anno_path.")
parser.add_argument("--data_dir", type=str, default="", help="Dataset directory, "
"the absolute image path is obtained by joining data_dir "
"and the relative path in anno_path")
parser.add_argument("--visual_image", type=str, default="false", help="Visualize the ground truth and predicted images")
parser.add_argument("--save_result_dir", type=str, default="", help="The path to save the predict results")
args_opt = parser.parse_args()
def _set_parallel_all_reduce_split():
"""set centernet all_reduce fusion split"""
if net_config.last_level == 5:
context.set_auto_parallel_context(all_reduce_fusion_config=[16, 56, 96, 136, 175])
elif net_config.last_level == 6:
context.set_auto_parallel_context(all_reduce_fusion_config=[18, 59, 100, 141, 182])
else:
raise ValueError("The total number of all-reduced grads for last_level = {} is unknown, "
"please re-split after the true value is known".format(net_config.last_level))
def _get_params_groups(network, optimizer):
"""
Get param groups
"""
params = network.trainable_params()
decay_params = list(filter(lambda x: not optimizer.decay_filter(x), params))
other_params = list(filter(optimizer.decay_filter, params))
group_params = [{'params': decay_params, 'weight_decay': optimizer.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
return group_params
def _get_optimizer(network, dataset_size):
"""get optimizer, only support Adam right now."""
if train_config.optimizer == 'Adam':
group_params = _get_params_groups(network, train_config.Adam)
if train_config.lr_schedule == "PolyDecay":
lr_schedule = CenterNetPolynomialDecayLR(learning_rate=train_config.PolyDecay.learning_rate,
end_learning_rate=train_config.PolyDecay.end_learning_rate,
warmup_steps=train_config.PolyDecay.warmup_steps,
decay_steps=args_opt.train_steps,
power=train_config.PolyDecay.power)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.PolyDecay.eps, loss_scale=1.0)
elif train_config.lr_schedule == "MultiDecay":
multi_epochs = train_config.MultiDecay.multi_epochs
if not isinstance(multi_epochs, (list, tuple)):
raise TypeError("multi_epochs must be list or tuple.")
if not multi_epochs:
multi_epochs = [args_opt.epoch_size]
lr_schedule = CenterNetMultiEpochsDecayLR(learning_rate=train_config.MultiDecay.learning_rate,
warmup_steps=train_config.MultiDecay.warmup_steps,
multi_epochs=multi_epochs,
steps_per_epoch=dataset_size,
factor=train_config.MultiDecay.factor)
optimizer = Adam(group_params, learning_rate=lr_schedule, eps=train_config.MultiDecay.eps, loss_scale=1.0)
else:
raise ValueError("Unsupported lr_schedule {}, only [PolyDecay, MultiDecay] are supported".
format(train_config.lr_schedule))
else:
raise ValueError("Unsupported optimizer {}, only [Adam] is supported currently".
format(train_config.optimizer))
return optimizer
def train():
"""training CenterNet"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_auto_mixed_precision=False)
context.set_context(reserve_class_name_in_scope=False)
context.set_context(save_graphs=False)
ckpt_save_dir = args_opt.save_checkpoint_path
if args_opt.distribute == "true":
D.init()
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
_set_parallel_all_reduce_split()
else:
rank = 0
device_num = 1
num_workers = device_num * 8
# Start create dataset!
# mindrecord files will be generated at args_opt.mindrecord_dir such as centernet.mindrecord0, 1, ... file_num.
logger.info("Begin creating dataset for CenterNet")
prefix = "coco_hp.train.mind"
coco = COCOHP(args_opt.data_dir, dataset_config, net_config, run_mode="train")
coco.init(enable_visual_image=(args_opt.visual_image == "true"), save_path=args_opt.save_result_dir)
dataset = coco.create_train_dataset(args_opt.mindrecord_dir, prefix, batch_size=train_config.batch_size,
device_num=device_num, rank=rank, num_parallel_workers=num_workers,
do_shuffle=args_opt.do_shuffle == 'true')
dataset_size = dataset.get_dataset_size()
logger.info("Create dataset done!")
net_with_loss = CenterNetMultiPoseLossCell(net_config)
new_repeat_count = args_opt.epoch_size * dataset_size // args_opt.data_sink_steps
if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
else:
args_opt.train_steps = args_opt.epoch_size * dataset_size
logger.info("train steps: {}".format(args_opt.train_steps))
optimizer = _get_optimizer(net_with_loss, dataset_size)
callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(dataset_size)]
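# save checkpoints on a single device only: the one whose device_id is divisible by min(8, device_num)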
if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_centernet',
directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.load_checkpoint_path:
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
load_param_into_net(net_with_loss, param_dict)
net_with_grads = CenterNetWithLossScaleCell(net_with_loss, optimizer=optimizer,
sens=train_config.loss_scale_value)
model = Model(net_with_grads)
model.train(new_repeat_count, dataset, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"),
sink_size=args_opt.data_sink_steps)
if __name__ == '__main__':
if args_opt.need_profiler == "true":
profiler = Profiler(output_path=args_opt.profiler_path)
set_seed(0)
train()
if args_opt.need_profiler == "true":
profiler.analyse()