You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/model_zoo/official/cv/centerface/test.py

161 lines
6.9 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test centerface example
"""
import os
import time
import argparse
import datetime
import scipy.io as sio
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.utils import get_logger
from src.var_init import default_recurisive_init
from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
from src.config import ConfigCenterface
from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
from dependency.evaluate.eval import evaluation
# Device index is read from the DEVICE_ID environment variable. Fall back to
# device 0 when the variable is unset; os.getenv would otherwise return None
# and int(None) would abort the script with a TypeError before any logging.
dev_id = int(os.getenv('DEVICE_ID', '0'))

# Configure MindSpore once, at import time: graph mode on the Ascend target,
# with auto mixed precision and graph dumping both turned off.
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
                    device_target="Ascend", save_graphs=False, device_id=dev_id)
# Command-line interface of the test script. The first positional argument of
# ArgumentParser is ``prog``, so the string below is the program name shown in
# usage/help output.
parser = argparse.ArgumentParser(prog='mindspore coco training')

# One row per option: (flag, value type, default, help text). Registration
# order matches the order of the rows, which is also the order shown by --help.
_CLI_OPTIONS = (
    ('--data_dir', str, '', 'train data dir'),
    ('--test_model', str, '', 'test model dir'),
    ('--ground_truth_mat', str, '', 'ground_truth, mat type'),
    ('--save_dir', str, '', 'save_path for evaluate'),
    ('--ground_truth_path', str, '', 'ground_truth path, contain all mat file'),
    ('--eval', int, 0, 'if do eval after test'),
    ('--eval_script_path', str, '', 'evaluate script path'),
    ('--rank', int, 0, 'local rank of distributed'),
    ('--ckpt_path', str, 'outputs/', 'checkpoint save location'),
    ('--ckpt_name', str, "", 'input model name'),
    ('--device_num', int, 1, 'device num for testing'),
    ('--steps_per_epoch', int, 198, 'steps for each epoch'),
    ('--start', int, 0, 'start loop number, used to calculate first epoch number'),
    ('--end', int, 18, 'end loop number, used to calculate last epoch number'),
)
for _flag, _value_type, _default, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, default=_default, help=_help_text)

# parse_known_args tolerates unrecognised options; they are collected in ``_``
# and otherwise ignored.
args, _ = parser.parse_known_args()
def strip_training_prefixes(param_dict):
    """Return checkpoint parameters renamed for loading into the bare network.

    Args:
        param_dict: mapping of parameter name to value, as returned by
            ``load_checkpoint``.

    Returns:
        A new dict in which
        * entries whose name starts with ``moments.``, ``moment1.`` or
          ``moment2.`` are dropped (presumably optimizer state -- confirm),
        * a leading ``centerface_network.`` is removed from the name,
        * every other entry is kept unchanged.
    """
    wrapper_prefix = 'centerface_network.'
    skipped_prefixes = ('moments.', 'moment1.', 'moment2.')
    renamed = {}
    for key, value in param_dict.items():
        if key.startswith(skipped_prefixes):
            continue
        if key.startswith(wrapper_prefix):
            # Slice by the prefix length instead of a hard-coded 19.
            key = key[len(wrapper_prefix):]
        renamed[key] = value
    return renamed


def write_detections(txt_path, image_rel_path, boxes):
    """Write the detections of one image to ``txt_path``.

    File layout: line 1 is ``image_rel_path``, line 2 is the number of boxes,
    followed by one ``x y w h score`` line per box.

    Args:
        txt_path: path of the text file to (over)write.
        image_rel_path: string written on the first line.
        boxes: iterable with a length whose items are indexable as
            ``(x1, y1, x2, y2, score)``.
    """
    # The context manager closes the file even if formatting a box raises.
    with open(txt_path, 'w') as result_file:
        result_file.write('{:s}\n'.format(image_rel_path))
        # The count must describe the box lines that follow.
        result_file.write('{:d}\n'.format(len(boxes)))
        for box in boxes:
            x1, y1, x2, y2, score = box[0], box[1], box[2], box[3], box[4]
            # Corner coordinates are converted to width/height (inclusive).
            result_file.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(
                x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), score))


if __name__ == "__main__":
    # logger: one time-stamped output directory per run
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    # An explicit checkpoint name means exactly one checkpoint is tested.
    if args.ckpt_name != "":
        args.start = 0
        args.end = 1

    for loop in range(args.start, args.end, 1):
        network = CenterfaceMobilev2()
        default_recurisive_init(network)

        if args.ckpt_name == "":
            # Derive the checkpoint file name from the loop index, rank and
            # steps per epoch.
            ckpt_num = loop * args.device_num + args.rank + 1
            ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt"
        else:
            ckpt_name = args.ckpt_name

        test_model = args.test_model + ckpt_name
        if not test_model:
            args.logger.info('load_model {} none'.format(test_model))
            continue

        if os.path.isfile(test_model):
            param_dict = load_checkpoint(test_model)
            load_param_into_net(network, strip_training_prefixes(param_dict))
            args.logger.info('load_model {} success'.format(test_model))
        else:
            args.logger.info('{} not exists or not a pre-trained file'.format(test_model))
            continue

        train_network_type_nms = 1  # default: wrap the network with NMS
        if train_network_type_nms:
            network = CenterFaceWithNms(network)
            args.logger.info('train network type with nms')
        network.set_train(False)
        args.logger.info('finish get network')

        config = ConfigCenterface()

        # test network -----------
        start = time.time()

        ground_truth_mat = sio.loadmat(args.ground_truth_mat)
        event_list = ground_truth_mat['event_list']
        file_list = ground_truth_mat['file_list']
        if args.ckpt_name == "":
            save_path = args.save_dir + str(ckpt_num) + '/'
        else:
            save_path = args.save_dir + '/'
        detector = CenterFaceDetector(config, network)

        # Keep ``num`` bound for the timing log below even when the ground
        # truth lists no images at all.
        num = 0
        for index, event in enumerate(event_list):
            file_list_item = file_list[index][0]
            im_dir = event[0][0]
            if not os.path.exists(save_path + im_dir):
                os.makedirs(save_path + im_dir)
                args.logger.info('save_path + im_dir={}'.format(save_path + im_dir))
            for num, file in enumerate(file_list_item):
                im_name = file[0][0]
                zip_name = '%s/%s.jpg' % (im_dir, im_name)
                img_path = os.path.join(args.data_dir, zip_name)
                args.logger.info('img_path={}'.format(img_path))

                dets = detector.run(img_path)['results']
                # NOTE(review): index 1 is presumably the face class id in the
                # detector's result mapping -- confirm against CenterFaceDetector.
                write_detections(save_path + im_dir + '/' + im_name + '.txt', zip_name, dets[1])
                args.logger.info('event:{}, num:{}'.format(index + 1, num + 1))

        end = time.time()
        args.logger.info("============num {} time {}".format(num, (end-start)*1000))
        start = end

        if args.eval:
            args.logger.info('==========start eval===============')
            args.logger.info("test output path = {}".format(save_path))
            if os.path.isdir(save_path):
                evaluation(save_path, args.ground_truth_path)
            else:
                args.logger.info('no test output path')
            args.logger.info('==========end eval===============')

        if args.ckpt_name != "":
            break

    args.logger.info('==========end testing===============')