diff --git a/model_zoo/official/cv/centerface/README.MD b/model_zoo/official/cv/centerface/README.MD
new file mode 100644
index 0000000000..651736a739
--- /dev/null
+++ b/model_zoo/official/cv/centerface/README.MD
@@ -0,0 +1,502 @@
+# Contents
+
+- [CenterFace Description](#CenterFace-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Training](#training)
+    - [Testing Process](#testing-process)
+        - [Testing](#testing)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+    - [Convert Process](#convert-process)
+        - [Convert](#convert)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Evaluation Performance](#evaluation-performance)
+        - [Inference Performance](#inference-performance)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+
+# [CenterFace Description](#contents)
+
+CenterFace is a practical anchor-free face detection and alignment method for edge devices. This repository supports training and evaluation on Ascend 910.
+
+Face detection and alignment in unconstrained environments is often deployed on edge devices, which have limited memory and low computing power.
+CenterFace proposes a one-stage method that simultaneously predicts the facial box and landmark locations with real-time speed and high accuracy.
+
+[Paper](https://arxiv.org/ftp/arxiv/papers/1911/1911.03599.pdf): CenterFace: Joint Face Detection and Alignment Using Face as Point.
+Xu, Yuanyuan(Huaqiao University) and Yan, Wan(StarClouds) and Sun, Haixin(Xiamen University)
+and Yang, Genke(Shanghai Jiaotong University) and Luo, Jiliang(Huaqiao University)
+
+# [Model Architecture](#contents)
+
+CenterFace uses a pretrained mobilenet_v2 backbone, adds a 4-layer FPN, and attaches four heads.
+Four losses are computed, and the total loss is their weighted mean.
+
+# [Dataset](#contents)
+
+Note that you can run the scripts with the dataset mentioned in the original paper or with any dataset widely used in this domain/for this network architecture. The following sections introduce how to run the scripts with the dataset below.
+
+Dataset support: [WiderFace] or datasets with the same format as WiderFace
+Annotation support: [WiderFace] or annotations in the same format as WiderFace
+
+- The directory structure is as follows; directory and file names are user-defined:
+    ```
+    ├── dataset
+        ├── centerface
+            ├── annotations
+            │   ├─ train.json
+            │   └─ val.json
+            ├─ images
+            │   ├─ train
+            │   │    └─images
+            │   │       ├─class1_image_folder
+            │   │       ├─ ...
+            │   │       └─classn_image_folder
+            │   └─ val
+            │       └─images
+            │           ├─class1_image_folder
+            │           ├─ ...
+            │           └─classn_image_folder
+            └─ ground_truth
+               ├─val.mat
+               ├─ ...
+               └─xxx.mat
+    ```
+We suggest users try the WiderFace dataset first;
+other datasets must follow the same format as WiderFace.
+
+# [Environment Requirements](#contents)
+
+- Hardware(Ascend)
+  - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
+- Framework + - [MindSpore](https://cmc-szv.clouddragon.huawei.com/cmcversion/index/search?searchKey=Do-MindSpore%20V100R001C00B622) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) + - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) + +# [Quick Start](#contents) + +After installing MindSpore via the official website, you can start training and evaluation as follows: + +step1: prepare pretrained model: train a mobilenet_v2 model by mindspore or use the script below: +```python +#CenterFace need a pretrained mobilenet_v2 model: +# mobilenet_v2_key.ckpt is a model with all value zero, we need the key/cell/module name for this model. +# you must first use this script to convert your mobilenet_v2 pytorch model to mindspore model as a pretrain model. +# The key/cell/module name must as follow, otherwise you need to modify "name_map" function: +# --mindspore: as the same as mobilenet_v2_key.ckpt +# --pytorch: same as official pytorch model(e.g., official mobilenet_v2-b0353104.pth) +python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt +``` +step2: prepare user rank_table +```python +# user can use your own rank table file +# or use the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate rank table file +# e.g., python hccl_tools.py --device_num "[0,8)" +python hccl_tools.py --device_num "[0,8)" +``` +step3: train +```python +cd scripts; +# prepare data_path, use symbolic link +ln -sf [USE_DATA_DIR] dataset +# check you dir to make sure your datas are in the right path +ls ./dataset/centerface # data path +ls ./dataset/centerface/annotations/train.json # annot_path +ls ./dataset/centerface/images/train/images # img_dir +``` +```python +# enter script dir, train CenterFace +sh train_distribute.sh +# after training +mkdir ./model +cp device0/outputs/*/*.ckpt ./model # cp model to [MODEL_PATH] +``` +step4: test +```python +# test CenterFace preparing +cd ../dependency/centernet/src/lib/external; +python setup.py install; +make; +cd -; #cd ../../../../../scripts; +cd ../dependency/evaluate; +python setup.py install; # used for eval +cd -; #cd ../../scripts; +mkdir ./output +mkdir ./output/centerface +# check you dir to make sure your datas are in the right path +ls ./dataset/images/val/images/ # data path +ls ./dataset/centerface/ground_truth/val.mat # annot_path +``` +```python +# test CenterFace +sh test_distribute.sh +``` +step5: eval +```python +# after test, eval CenterFace, get MAP +# cd ../dependency/evaluate; +# python setup.py install; +# cd -; #cd ../../scripts; +sh eval_all.sh +``` + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +``` +├── cv + ├── centerface + ├── train.py // training scripts + ├── test.py // testing training outputs + ├── export.py // convert mindspore model to air model + ├── README.md // descriptions about CenterFace + ├── scripts + │ ├──eval.sh // evaluate a single testing result + │ ├──eval_all.sh // choose a range of testing results to evaluate + │ ├──test.sh // testing a single model + │ ├──test_distribute.sh // testing a range of models + │ ├──test_and_eval.sh // test then evaluate a single model + │ ├──train_standalone.sh // train in ascend with single npu + │ ├──train_distribute.sh // train in ascend with multi npu + ├── src + │ ├──__init__.py + │ ├──centerface.py // centerface 
networks, training entry + │ ├──dataset.py // generate dataloader and data processing entry + │ ├──config.py // centerface unique configs + │ ├──losses.py // losses for centerface + │ ├──lr_scheduler.py // learning rate scheduler + │ ├──mobile_v2.py // modified mobilenet_v2 backbone + │ ├──utils.py // auxiliary functions for train, to log and preload + │ ├──var_init.py // weight initilization + │ ├──convert_weight_mobilenetv2.py // convert pretrained backbone to mindspore + │ ├──convert_weight.py // CenterFace model convert to mindspore + └── dependency // third party codes: MIT License + ├──extd // training dependency: data augmentation + │ ├──utils + │ │ └──augmentations.py // data anchor sample of PyramidBox to generate small images + ├──evaluate // evaluate dependency + │ ├──box_overlaps.pyx // box overlaps + │ ├──setup.py // setupfile for box_overlaps.pyx + │ ├──eval.py // evaluate testing results + └──centernet // modified from 'centernet' + └──src + └──lib + ├──datasets + │ ├──dataset // train dataset core + │ │ ├──coco_hp.py // read and formatting data + │ ├──sample + │ │ └──multi_pose.py // core for data processing + ├──detectors // test core, including running, pre-processing and post-processing + │ ├──base_detector.py // user can add your own test core; for example, use pytorch or tf for pre/post processing + ├──external // test dependency + │ ├──__init__.py + │ ├──Makefile // makefile for nms + │ ├──nms.pyx // use soft_nms + │ ├──setup.py // setupfile for nms.pyx + └──utils + └──image.py // image processing functions +``` + +## [Script Parameters](#contents) +1. train scripts parameters +the command is: python train.py [train parameters] +Major parameters train.py as follows: +```python +--lr: learning rate +--per_batch_size: batch size on each device +--is_distributed: multi-device or not +--t_max: for cosine lr_scheduler +--max_epoch: training epochs +--warmup_epochs: warmup_epochs, not needed for adam, needed for sgd +--lr scheduler: learning rate scheduler, default is multistep +--lr_epochs: decrease lr steps +--lr_gamma: decrease lr by a factor +--weight_decay: weight decay +--loss_scale: mix precision training +--pretrained_backbone: pretrained mobilenet_v2 model path +--data_dir: data dir +--annot_path: annotations path +--img_dir: img dir in data_dir +``` +2. centerface unique configs: in config.py; not recommend user to change + +3. test scripts parameters: +the command is: python test.py [test parameters] +Major parameters test.py as follows: +```python +test_script_path: test.py path; +--is_distributed: multi-device or not +--data_dir: img dir +--test_model: test model dir +--ground_truth_mat: ground_truth file, mat type +--save_dir: save_path for evaluate +--rank: use device id +--ckpt_name: test model name +# blow are used for calculate ckpt/model name +# model/ckpt name is "0-" + str(ckpt_num) + "_" + str(steps_per_epoch*ckpt_num) + ".ckpt"; +# ckpt_num is epoch number, can be calculated by device_num +# detail can be found in "test.py" +# if ckpt is specified not need below 4 parameter +--device_num: training device number +--steps_per_epoch: steps for each epoch +--start: start loop number, used to calculate first epoch number +--end: end loop number, used to calculate last epoch number +``` + +4. 
eval scripts parameters: +the command is: python eval.py [pred] [gt] +Major parameters eval.py as follows: +```python +--pred: pred path, test output test.py->[--save_dir] +--gt: ground truth path +``` + +## [Training Process](#contents) + +### Training + +'task_set' is important for multi-npu train to get higher speed +--task_set: 0, not task_set; 1 task_set; +--task_set_core: task_set core number, most time = cpu number/nproc_per_node + +step1: user need train a mobilenet_v2 model by mindspore or use the script below: +```python +python torch_to_ms_mobilenetv2.py --ckpt_fn=./mobilenet_v2_key.ckpt --pt_fn=./mobilenet_v2-b0353104.pth --out_ckpt_fn=./mobilenet_v2.ckpt +``` +step2: prepare user rank_table +```python +# user can use your own rank table file +# or use the [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools) to generate rank table file +# e.g., python hccl_tools.py --device_num "[0,8)" +python hccl_tools.py --device_num "[0,8)" +``` +step3: train +- Single device +```python +# enter script dir, train CenterFace +cd scripts +# you need to change the parameter in train_standalone.sh +# or use symbolic link as quick start +# or use the command as follow: +# USE_DEVICE_ID: your device +# PRETRAINED_BACKBONE: your pretrained model path +# DATASET: dataset path +# ANNOTATIONS: annotation path +# images: img_dir in dataset path +sh train_standalone.sh [USE_DEVICE_ID] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS] [IMAGES] +# after training +cp device0/outputs/*/*.ckpt [MODEL_PATH] +``` +- multi-device (recommended) +```python +# enter script dir, train CenterFace +cd scripts; +# you need to change the parameter in train_distribute.sh +# or use symbolic link as quick start +# or use the command as follow, most are the same as train_standalone.sh, the different is RANK_TABLE +# RANK_TABLE: for multi-device only, from generate_rank_table.py or user writing +sh train_distribute.sh [RANK_TABLE] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS] [IMAGES] +# after training +cp device0/outputs/*/*.ckpt [MODEL_PATH] +``` +After training with 8 device, the loss value will be achieved as follows: +```python +# grep "loss is " device0/xxx.log +# epoch: 1 step: 1, loss is greater than 500 and less than 5000 +2020-09-24 19:00:53,550:INFO:epoch:1, iter:0, average_loss:loss:1148.415649, loss:1148.4156494140625, overflow:False, loss_scale:1024.0 +[WARNING] DEBUG(51499,python):2020-09-24-19:00:53.590.008 [mindspore/ccsrc/debug/dump_proto.cc:218] SetValueToProto] Unsupported type UInt +2020-09-24 19:00:53,784:INFO:epoch:1, iter:1, average_loss:loss:798.286713, loss:448.15777587890625, overflow:False, loss_scale:1024.0 +... +2020-09-24 19:01:58,095:INFO:epoch:2, iter:197, average_loss:loss:1.942609, loss:1.5492267608642578, overflow:False, loss_scale:1024.0 +2020-09-24 19:01:58,501:INFO:epoch[2], loss:1.942609, 477.97 imgs/sec, lr:0.004000000189989805 +2020-09-24 19:01:58,502:INFO:==========end epoch=============== +2020-09-24 19:02:00,780:INFO:epoch:3, iter:0, average_loss:loss:2.107658, loss:2.1076583862304688, overflow:False, loss_scale:1024.0 +... 
+# epoch: 140 average loss is greater than 0.3 and less than 1.5:
+2020-09-24 20:19:16,255:INFO:epoch:140, iter:196, average_loss:loss:0.906300, loss:1.1071504354476929, overflow:False, loss_scale:1024.0
+2020-09-24 20:19:16,347:INFO:epoch:140, iter:197, average_loss:loss:0.904684, loss:0.586264967918396, overflow:False, loss_scale:1024.0
+2020-09-24 20:19:16,747:INFO:epoch[140], loss:0.904684, 480.10 imgs/sec, lr:3.9999998989515007e-05
+2020-09-24 20:19:16,748:INFO:==========end epoch===============
+2020-09-24 20:19:16,748:INFO:==========end training===============
+```
+The model checkpoint will be saved in scripts/device0/output/xxx/xxx.ckpt.
+
+## [Testing Process](#contents)
+
+### Testing
+
+```python
+# after training, prepare to test CenterFace
+cd scripts;
+cd ../dependency/centernet/src/lib/external;
+python setup.py install;
+make;
+cd ../../../../../scripts;
+mkdir [SAVE_PATH]
+```
+1. test a single ckpt file
+```python
+# you need to change the parameters in test.sh
+# or use a symbolic link as in the quick start
+# or use the command as follows:
+# MODEL_PATH: ckpt path saved during training
+# DATASET: img dir
+# GROUND_TRUTH_MAT: ground_truth file, mat type
+# SAVE_PATH: save path for evaluation
+# DEVICE_ID: device id to use
+# CKPT: test model name
+sh test.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID] [CKPT]
+```
+2. test a range of output ckpt files so the user can choose the best one
+```python
+# you need to change the parameters in test.sh
+# or use a symbolic link as in the quick start
+# or use the command as follows; most parameters are the same as test.sh, the differences are:
+# DEVICE_NUM: training device number
+# STEPS_PER_EPOCH: steps for each epoch
+# START: start loop number, used to calculate the first epoch number
+# END: end loop number, used to calculate the last epoch number
+sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM] [STEPS_PER_EPOCH] [START] [END]
+```
+After testing, you will find many txt files that save the box positions and scores;
+open one and you will see:
+```python
+646.3 189.1 42.1 51.8 0.747 # left top width height score
+157.4 408.6 43.1 54.1 0.667
+120.3 212.4 38.7 42.8 0.650
+...
+```
+## [Evaluation Process](#contents)
+
+### Evaluation
+
+```python
+# after testing, prepare to evaluate CenterFace and get the mAP
+cd ../dependency/evaluate;
+python setup.py install;
+cd ../../scripts;
+```
+1. evaluate a single testing output
+```python
+# you need to change the parameters in eval.sh
+# by default it evaluates the result saved in ./scripts/output/centerface/999
+sh eval.sh
+```
+2. evaluate a range of testing outputs so the user can choose the best one
+```python
+# you need to change the parameters in eval_all.sh
+# by default it evaluates the results saved in ./scripts/output/centerface/[89-140]
+sh eval_all.sh
+```
+3.
test+eval +```python +# you need to change the parameter in test_and_eval.sh +# or use symbolic link as quick start, default eval the ckpt saved in ./scripts/output/centerface/999 +# or use the command as follow, most are the same as test.sh, the different are: +# GROUND_TRUTH_PATH: ground truth path +sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [CKPT] [GROUND_TRUTH_PATH] +``` +you can see the MAP below by eval.sh +``` +(ci3.7) [root@bms-aiserver scripts]# ./eval.sh +start eval +==================== Results = ==================== ./scripts/output/centerface/999 +Easy Val AP: 0.923914407045363 +Medium Val AP: 0.9166100571371586 +Hard Val AP: 0.7810750535799462 +================================================= +end eval +``` + +you can see the MAP below by eval_all.sh +``` +(ci3.7) [root@bms-aiserver scripts]# ./eval_all.sh +==================== Results = ==================== ./scripts/output/centerface/89 +Easy Val AP: 0.8884892849068273 +Medium Val AP: 0.8928813452811216 +Hard Val AP: 0.7721131614294564 +================================================= +==================== Results = ==================== ./scripts/output/centerface/90 +Easy Val AP: 0.8836073914165545 +Medium Val AP: 0.8875938506473486 +Hard Val AP: 0.775956751740446 +... +==================== Results = ==================== ./scripts/output/centerface/125 +Easy Val AP: 0.923914407045363 +Medium Val AP: 0.9166100571371586 +Hard Val AP: 0.7810750535799462 +================================================= +==================== Results = ==================== ./scripts/output/centerface/126 +Easy Val AP: 0.9218741197149122 +Medium Val AP: 0.9151860193570651 +Hard Val AP: 0.7825645670331809 +... +==================== Results = ==================== ./scripts/output/centerface/140 +Easy Val AP: 0.9250715236965638 +Medium Val AP: 0.9170429723233877 +Hard Val AP: 0.7822182013830674 +================================================= +``` +## [Convert Process](#contents) + +### Convert +If you want to infer the network on Ascend 310, you should convert the model to AIR: + +```python +python export.py [BATCH_SIZE] [PRETRAINED_BACKBONE] +``` + +# [Model Description](#contents) + +## [Performance](#contents) + +### Evaluation Performance +CenterFace on 13K images(The annotation and data format must be the same as widerFace) + +| Parameters | CenterFace | +| -------------------------- | ----------------------------------------------------------- | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G | +| uploaded Date | 10/29/2020 (month/day/year) | +| MindSpore Version | 1.0.0 | +| Dataset | 13K images | +| Training Parameters | epoch=140, steps=198 * epoch, batch_size = 8, lr=0.004 | +| Optimizer | Adam | +| Loss Function | Focal Loss, L1 Loss, Smooth L1 Loss | +| outputs | heatmaps | +| Loss | 0.3-1.5, average loss for last epoch is in 0.8-1.0 | +| Speed | 1p 65 img/s, 8p 475 img/s | +| Total time | train(8p) 1.1h, test 50min, eval 5-10min | +| Checkpoint for Fine tuning | 22M (.ckpt file) | +| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/centerface | + +### Inference Performance +CenterFace on 3.2K images(The annotation and data format must be the same as widerFace) + +| Parameters | CenterFace | +| -------------------------- | ----------------------------------------------------------- | +| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory, 755G | +| uploaded Date | 10/29/2020 (month/day/year) | +| MindSpore Version | 1.0.0 | +| Dataset | 3.2K images | +| 
batch_size | 1 | +| outputs | box position and sorces, and probability | +| Accuracy | Easy 92.2% Medium 91.5% Hard 78.2% (+-0.5%) | +| Model for inference | 22M (.ckpt file) | + +# [Description of Random Situation](#contents) + +In dataset.py, we set the seed inside ```create_dataset``` function. +In var_init.py, we set seed for weight initilization + +# [ModelZoo Homepage](#contents) + Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/dataset/coco_hp.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/dataset/coco_hp.py new file mode 100644 index 0000000000..556f8b8fb7 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/dataset/coco_hp.py @@ -0,0 +1,89 @@ +""" +MIT License + +Copyright (c) 2019 Xingyi Zhou +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import pycocotools.coco as coco +import cv2 + +class CenterfaceDataset(): + """ + Centerface dataset definition. + """ + def __init__(self, config, split='train'): + self.split = split + self.config = config + self.max_objs = config.max_objs + self.img_dir = self.config.img_dir + self.annot_path = self.config.annot_path + + print('==> getting centerface key point {} data.'.format(split)) + self.coco = coco.COCO(self.annot_path) + image_ids = self.coco.getImgIds() + + if split == 'train': + self.images = [] + for img_id in image_ids: + idxs = self.coco.getAnnIds(imgIds=[img_id]) + if idxs: + self.images.append(img_id) + else: + self.images = image_ids + self.num_samples = len(self.images) + print('Loaded {} {} samples'.format(split, self.num_samples)) # Loaded train 12671 samples + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + (image, target) (tuple): target is index of the target class. 
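+            Note: despite the wording above, target is not a class index here; it is a list with
+            one entry per annotation, where each entry is the 4-value bbox followed by the
+            flattened keypoints (see the loop below).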
+ """ + img_id = self.images[index] + file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name'] + img_path = os.path.join(self.img_dir, file_name) + ann_ids = self.coco.getAnnIds(imgIds=[img_id]) + anns = self.coco.loadAnns(ids=ann_ids) + num_objs = len(anns) + if num_objs > self.max_objs: + num_objs = self.max_objs + anns = np.random.choice(anns, num_objs) + # dataType ERROR —— to_list + target = [] + for ann in anns: + tmp = [] + tmp.extend(ann['bbox']) + tmp.extend(ann['keypoints']) + target.append(tmp) + + img = cv2.imread(img_path) + return img, target + + def __len__(self): + return self.num_samples diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/sample/multi_pose.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/sample/multi_pose.py new file mode 100644 index 0000000000..5f4ff97eb3 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/datasets/sample/multi_pose.py @@ -0,0 +1,217 @@ +""" +MIT License + +Copyright (c) 2019 Xingyi Zhou +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import cv2 + +from dependency.centernet.src.lib.utils.image import color_aug +from dependency.centernet.src.lib.utils.image import get_affine_transform, affine_transform +from dependency.centernet.src.lib.utils.image import gaussian_radius, draw_umich_gaussian +from dependency.extd.utils.augmentations import anchor_crop_image_sampling + +def get_border(border, size): + """ + Get border + """ + i = 1 + while size - border // i <= border // i: # size > 2 * (border // i) + i *= 2 + return border // i + +def coco_box_to_bbox(box): + """ + (x1, y1, w, h) -> (x1, y1, x2, y2) + """ + bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]], dtype=np.float32) + return bbox + +def preprocess_train(image, target, config): + """ + Preprocess training data + """ + data_rng = np.random.RandomState(123) + eig_val = np.array([0.2141788, 0.01817699, 0.00341571], dtype=np.float32) + eig_vec = np.array([ + [-0.58752847, -0.69563484, 0.41340352], + [-0.5832747, 0.00994535, -0.81221408], + [-0.56089297, 0.71832671, 0.41158938] + ], dtype=np.float32) + mean = np.array([0.40789654, 0.44719302, 0.47026115], dtype=np.float32).reshape((1, 1, 3)) + std = np.array([0.28863828, 0.27408164, 0.27809835], dtype=np.float32).reshape((1, 1, 3)) + num_objs = len(target) + + anns = [] + for each in target: + ann = {} + ann['bbox'] = each[0:4] + ann['keypoints'] = each[4:] + anns.append(ann) + + cv2.setNumThreads(0) + img, anns = anchor_crop_image_sampling(image, anns) + + _, width = img.shape[0], img.shape[1] + c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32) + s = max(img.shape[0], img.shape[1]) * 1.0 + rot = 0 + flipped = False + if config.rand_crop: + #s = s * np.random.choice(np.arange(0.8, 1.3, 0.05)) # for 768*768 or 800* 800 + s = s * np.random.choice(np.arange(0.6, 1.0, 0.05)) # for 512 * 512 + border = s * np.random.choice([0.1, 0.2, 0.25]) + w_border = get_border(border, img.shape[1]) # w > 2 * w_border + h_border = get_border(border, img.shape[0]) # h > 2 * h_border + c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border) + c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border) + else: + sf = config.scale + cf = config.shift + c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf) + c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf) + s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf) + if np.random.random() < config.rotate: + rf = config.rotate + rot = np.clip(np.random.randn() * rf, -rf * 2, rf * 2) + + if np.random.random() < config.flip: # opt.flip = 0.5 + flipped = True + img = img[:, ::-1, :] + c[0] = width - c[0] - 1 + + trans_input = get_affine_transform(c, s, rot, [config.input_res, config.input_res]) + inp = cv2.warpAffine(img, trans_input, (config.input_res, config.input_res), flags=cv2.INTER_LINEAR) + + inp = (inp.astype(np.float32) / 255.) 
+ if config.color_aug: + color_aug(data_rng, inp, eig_val, eig_vec) + + inp = (inp - mean) / std + inp = inp.transpose(2, 0, 1) + + output_res = config.output_res + num_joints = config.num_joints + max_objs = config.max_objs + trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res]) + trans_output = get_affine_transform(c, s, 0, [output_res, output_res]) + + # map + hm = np.zeros((config.num_classes, output_res, output_res), dtype=np.float32) + hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32) + + wh = np.zeros((output_res, output_res, 2), dtype=np.float32) + reg = np.zeros((output_res, output_res, 2), dtype=np.float32) + ind = np.zeros((output_res, output_res), dtype=np.float32) # as float32, need no data_type change later + + reg_mask = np.zeros((max_objs), dtype=np.uint8) + wight_mask = np.zeros((output_res, output_res, 2), dtype=np.float32) + + kps = np.zeros((output_res, output_res, num_joints * 2), dtype=np.float32) + kps_mask = np.zeros((output_res, output_res, num_joints * 2), dtype=np.float32) + # + hp_offset = np.zeros((max_objs * num_joints, 2), dtype=np.float32) + hp_ind = np.zeros((max_objs * num_joints), dtype=np.int64) + hp_mask = np.zeros((max_objs * num_joints), dtype=np.int64) + + draw_gaussian = draw_umich_gaussian + + gt_det = [] + for k in range(num_objs): + ann = anns[k] + bbox = coco_box_to_bbox(ann['bbox']) # [x,y,w,h]--[x1,y1,x2,y2] + cls_id = 0 #int(ann['category_id']) - 1 + pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3) # (x,y,0/1) + if flipped: + bbox[[0, 2]] = width - bbox[[2, 0]] - 1 + pts[:, 0] = width - pts[:, 0] - 1 + for e in config.flip_idx: # flip_idx = [[0, 1], [3, 4]] + pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy() + + bbox[:2] = affine_transform(bbox[:2], trans_output) # [0, 1] -- (x1, y1) + bbox[2:] = affine_transform(bbox[2:], trans_output) # [2, 3] -- (x2, y2) + bbox = np.clip(bbox, 0, output_res - 1) + h, w = bbox[3] - bbox[1], bbox[2] - bbox[0] + if (h > 0 and w > 0) or (rot != 0): + radius = gaussian_radius((math.ceil(h), math.ceil(w))) + radius = max(0, int(radius)) + ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32) + ct_int = ct.astype(np.int32) + + ind[ct_int[1], ct_int[0]] = 1.0 + wh[ct_int[1], ct_int[0], :] = np.log(1. * w / 4), np.log(1. 
* h / 4) + reg[ct_int[1], ct_int[0], :] = ct[0] - ct_int[0], ct[1] - ct_int[1] + + reg_mask[k] = 1.0 + wight_mask[ct_int[1], ct_int[0], 0] = 1 + wight_mask[ct_int[1], ct_int[0], 1] = 1 + + # if w*h <= 20: # can get what we want sometime, but unstable + # wight_mask[k] = 15 + if w*h <= 40: + wight_mask[ct_int[1], ct_int[0], 0] = 5 + wight_mask[ct_int[1], ct_int[0], 1] = 5 + if w*h <= 20: + wight_mask[ct_int[1], ct_int[0], 0] = 10 + wight_mask[ct_int[1], ct_int[0], 1] = 10 + if w*h <= 10: + wight_mask[ct_int[1], ct_int[0], 0] = 15 + wight_mask[ct_int[1], ct_int[0], 1] = 15 + if w*h <= 4: + wight_mask[ct_int[1], ct_int[0], 0] = 0.1 + wight_mask[ct_int[1], ct_int[0], 1] = 0.1 + + num_kpts = pts[:, 2].sum() + if num_kpts == 0: + hm[cls_id, ct_int[1], ct_int[0]] = 0.9999 + + hp_radius = gaussian_radius((math.ceil(h), math.ceil(w))) + hp_radius = max(0, int(hp_radius)) + for j in range(num_joints): + if pts[j, 2] > 0: + pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot) + if pts[j, 0] >= 0 and pts[j, 0] < output_res and pts[j, 1] >= 0 and pts[j, 1] < output_res: + kps[ct_int[1], ct_int[0], j * 2 : j * 2 + 2] = pts[j, :2] - ct_int + kps[ct_int[1], ct_int[0], j * 2 : j * 2 + 1] = kps[ct_int[1], ct_int[0], j * 2 : j * 2 + 1] / w + kps[ct_int[1], ct_int[0], j * 2 + 1: j * 2 + 2] = kps[ct_int[1], ct_int[0], + j * 2 + 1 : j * 2 + 2] / h + kps_mask[ct_int[1], ct_int[0], j * 2 : j * 2 + 2] = 1.0 + + pt_int = pts[j, :2].astype(np.int32) + hp_offset[k * num_joints + j] = pts[j, :2] - pt_int + hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0] + hp_mask[k * num_joints + j] = 1 + + draw_gaussian(hm_hp[j], pt_int, hp_radius) + kps_mask[ct_int[1], ct_int[0], j * 2 : j * 2 + 2] = \ + 0.0 if ann['bbox'][2] * ann['bbox'][3] <= 8.0 else 1.0 + draw_gaussian(hm[cls_id], ct_int, radius) + gt_det.append([ct[0] - w / 2, ct[1] - h / 2, + ct[0] + w / 2, ct[1] + h / 2, 1] + + pts[:, :2].reshape(num_joints * 2).tolist() + [cls_id]) + + return inp, hm, reg_mask, ind, wh, wight_mask, reg, kps_mask, kps diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/detectors/base_detector.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/detectors/base_detector.py new file mode 100644 index 0000000000..0b46ddb4e0 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/detectors/base_detector.py @@ -0,0 +1,216 @@ +###modified based on centernet### +#MIT License +#Copyright (c) 2019 Xingyi Zhou +#All rights reserved. 
+"""Basic definition of detector""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 + +from mindspore import Tensor + +from dependency.centernet.src.lib.external.nms import soft_nms +from dependency.centernet.src.lib.utils.image import get_affine_transform, affine_transform + +def transform_preds(coords, center, scale, output_size): + """ + Transform target coords + """ + target_coords = np.zeros(coords.shape) + trans = get_affine_transform(center, scale, 0, output_size, inv=1) + for p in range(coords.shape[0]): + target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans) + return target_coords + +def multi_pose_post_process(dets, c, s, h, w): + """ + Multi pose post process + dets_result: 4 + score:1 + kpoints:10 + class:1 = 16 + dets: batch x max_dets x 40 + return list of 39 in image coord + """ + ret = [] + for i in range(dets.shape[0]): + bbox = transform_preds(dets[i, :, :4].reshape(-1, 2), c[i], s[i], (w, h)) + pts = transform_preds(dets[i, :, 5:15].reshape(-1, 2), c[i], s[i], (w, h)) + top_preds = np.concatenate([bbox.reshape(-1, 4), dets[i, :, 4:5], pts.reshape(-1, 10)], + axis=1).astype(np.float32).tolist() + ret.append({np.ones(1, dtype=np.int32)[0]: top_preds}) + return ret + +class CenterFaceDetector(): + """ + Centerface detector + """ + def __init__(self, opt, model): + self.flip_idx = opt.flip_idx + + print('Creating model...') + self.model = model + + self.mean = np.array(opt.mean, dtype=np.float32).reshape((1, 1, 3)) + self.std = np.array(opt.std, dtype=np.float32).reshape((1, 1, 3)) + self.max_per_image = 100 + self.num_classes = opt.num_classes + self.scales = opt.test_scales + self.opt = opt + self.pause = False + + def pre_process(self, image, scale, meta=None): + """ + Preprocess method + """ + height, width = image.shape[0:2] + new_height = int(height * scale) + new_width = int(width * scale) + if self.opt.fix_res: # True + inp_height, inp_width = self.opt.input_h, self.opt.input_w + c = np.array([new_width / 2., new_height / 2.], dtype=np.float32) + s = max(height, width) * 1.0 + else: + inp_height = int(np.ceil(new_height / 32) * 32) + inp_width = int(np.ceil(new_width / 32) * 32) + c = np.array([new_width // 2, new_height // 2], dtype=np.float32) + s = np.array([inp_width, inp_height], dtype=np.float32) + + trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height]) + resized_image = cv2.resize(image, (new_width, new_height)) + inp_image = cv2.warpAffine( + resized_image, trans_input, (inp_width, inp_height), + flags=cv2.INTER_LINEAR) + inp_image = ((inp_image / 255. 
- self.mean) / self.std).astype(np.float32) + + images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height, inp_width) + if self.opt.flip_test: + images = np.concatenate((images, images[:, :, :, ::-1]), axis=0) + + meta = {'c': c, 's': s, 'out_height': inp_height // self.opt.down_ratio, + 'out_width': inp_width // self.opt.down_ratio} + return images, meta + + def process(self, images): + """ + Process method + """ + images = Tensor(images) + # test with mindspore model + output_hm, output_wh, output_off, output_kps, topk_inds = self.model(images) + # Tensor to numpy + output_hm = output_hm.asnumpy().astype(np.float32) + output_wh = output_wh.asnumpy().astype(np.float32) + output_off = output_off.asnumpy().astype(np.float32) + output_kps = output_kps.asnumpy().astype(np.float32) + topk_inds = topk_inds.asnumpy().astype(np.long) + + reg = output_off if self.opt.reg_offset else None + + dets = self.centerface_decode(output_hm, output_wh, output_kps, reg=reg, opt_k=self.opt.K, topk_inds=topk_inds) + + return dets + + def post_process(self, dets, meta, scale=1): + """ + Post process process + """ + dets = dets.reshape(1, -1, dets.shape[2]) + dets = multi_pose_post_process( + dets.copy(), [meta['c']], [meta['s']], + meta['out_height'], meta['out_width']) + for j in range(1, self.num_classes + 1): + dets[0][j] = np.array(dets[0][j], dtype=np.float32).reshape(-1, 15) + # import pdb; pdb.set_trace() + dets[0][j][:, :4] /= scale + dets[0][j][:, 5:] /= scale + return dets[0] + + def merge_outputs(self, detections): + """ + Merge detection outputs + """ + results = {} + results[1] = np.concatenate([detection[1] for detection in detections], axis=0).astype(np.float32) + if self.opt.nms or len(self.opt.test_scales) > 1: + soft_nms(results[1], Nt=0.5, method=2) + results[1] = results[1].tolist() + return results + + def run(self, image_or_path_or_tensor, meta=None): + """ + Run method + """ + pre_processed = False + if isinstance(image_or_path_or_tensor, np.ndarray): + image = image_or_path_or_tensor + elif isinstance(image_or_path_or_tensor, str): + image = cv2.imread(image_or_path_or_tensor) + else: + image = image_or_path_or_tensor['image'][0].numpy() + pre_processed_images = image_or_path_or_tensor + pre_processed = True + + detections = [] + for scale in self.scales: # [1] + if not pre_processed: + images, meta = self.pre_process(image, scale, meta) # --1: pre_process + else: + # import pdb; pdb.set_trace() + images = pre_processed_images['images'][scale][0] + meta = pre_processed_images['meta'][scale] + meta = {k: v.numpy()[0] for k, v in meta.items()} + + dets = self.process(images) # --2: process + + dets = self.post_process(dets, meta, scale) # box:4+score:1+kpoints:10+class:1=16 ## --3: post_process + + detections.append(dets) + + results = self.merge_outputs(detections) # --4: merge_outputs + return {'results': results} + + def centerface_decode(self, heat, wh, kps, reg=None, opt_k=100, topk_inds=None): + """ + Decode detection bbox + """ + batch, _, _, width = wh.shape + + num_joints = kps.shape[1] // 2 + + scores = heat + inds = topk_inds + ys_int = (topk_inds / width).astype(np.int32).astype(np.float32) + xs_int = (topk_inds % width).astype(np.int32).astype(np.float32) + + reg = reg.reshape(batch, 2, -1) + reg_tmp = np.zeros((batch, 2, opt_k), dtype=np.float32) + for i in range(batch): + reg_tmp[i, 0, :] = reg[i, 0, inds[i]] + reg_tmp[i, 1, :] = reg[i, 1, inds[i]] + reg = reg_tmp.transpose(0, 2, 1) + + if reg is not None: + xs = xs_int.reshape(batch, opt_k, 1) + reg[:, :, 0:1] + ys = 
ys_int.reshape(batch, opt_k, 1) + reg[:, :, 1:2] + else: + xs = xs_int.reshape(batch, opt_k, 1) + 0.5 + ys = ys_int.reshape(batch, opt_k, 1) + 0.5 + + wh = wh.reshape(batch, 2, -1) + wh_tmp = np.zeros((batch, 2, opt_k), dtype=np.float32) + for i in range(batch): + wh_tmp[i, 0, :] = wh[i, 0, inds[i]] + wh_tmp[i, 1, :] = wh[i, 1, inds[i]] + + wh = wh_tmp.transpose(0, 2, 1) + wh = np.exp(wh) * 4. + scores = scores.reshape(batch, opt_k, 1) + bboxes = np.concatenate([xs - wh[..., 0:1] / 2, ys - wh[..., 1:2] / 2, xs + wh[..., 0:1] / 2, + ys + wh[..., 1:2] / 2], axis=2) + + clses = np.zeros((batch, opt_k, 1), dtype=np.float32) + kps = np.zeros((batch, opt_k, num_joints * 2), dtype=np.float32) + detections = np.concatenate([bboxes, scores, kps, clses], axis=2) # box:4+score:1+kpoints:10+class:1=16 + return detections diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/__init__.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/nms.pyx b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/nms.pyx new file mode 100644 index 0000000000..6499102354 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/nms.pyx @@ -0,0 +1,391 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ross Girshick +# -------------------------------------------------------- + +# ---------------------------------------------------------- +# Soft-NMS: Improving Object Detection With One Line of Code +# Copyright (c) University of Maryland, College Park +# Licensed under The MIT License [see LICENSE for details] +# Written by Navaneeth Bodla and Bharat Singh +# ---------------------------------------------------------- + +import numpy as np +cimport numpy as np + +cdef inline np.float32_t max(np.float32_t a, np.float32_t b): + return a if a >= b else b + +cdef inline np.float32_t min(np.float32_t a, np.float32_t b): + return a if a <= b else b + +def nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh): + cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] + cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] + cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] + cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] + cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] + + cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) + cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1] + + cdef int ndets = dets.shape[0] + cdef np.ndarray[np.int_t, ndim=1] suppressed = \ + np.zeros((ndets), dtype=np.int) + + # nominal indices + cdef int _i, _j + # sorted indices + cdef int i, j + # temp variables for box i's (the box currently under consideration) + cdef np.float32_t ix1, iy1, ix2, iy2, iarea + # variables for computing overlap with box j (lower scoring box) + cdef np.float32_t xx1, yy1, xx2, yy2 + cdef np.float32_t w, h + cdef np.float32_t inter, ovr + + keep = [] + for _i in range(ndets): + i = order[_i] + if suppressed[i] == 1: + continue + keep.append(i) + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + continue + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = 
min(iy2, y2[j]) + w = max(0.0, xx2 - xx1 + 1) + h = max(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= thresh: + suppressed[j] = 1 + + return keep + +def soft_nms(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + + for i in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + keep = [i for i in range(N)] + return keep + +def soft_nms_39(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + cdef float tmp + + for i in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + for j in range(5, 39): + tmp = boxes[i, j] + boxes[i, j] = boxes[maxpos, j] + boxes[maxpos, j] = tmp + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 
= boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + for j in range(5, 39): + tmp = boxes[pos, j] + boxes[pos, j] = boxes[N - 1, j] + boxes[N - 1, j] = tmp + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + keep = [i for i in range(N)] + return keep + +def soft_nms_merge(np.ndarray[float, ndim=2] boxes, float sigma=0.5, float Nt=0.3, float threshold=0.001, unsigned int method=0, float weight_exp=6): + cdef unsigned int N = boxes.shape[0] + cdef float iw, ih, box_area + cdef float ua + cdef int pos = 0 + cdef float maxscore = 0 + cdef int maxpos = 0 + cdef float x1,x2,y1,y2,tx1,tx2,ty1,ty2,ts,area,weight,ov + cdef float mx1,mx2,my1,my2,mts,mbs,mw + + for i in range(N): + maxscore = boxes[i, 4] + maxpos = i + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # get max box + while pos < N: + if maxscore < boxes[pos, 4]: + maxscore = boxes[pos, 4] + maxpos = pos + pos = pos + 1 + + # add max box as a detection + boxes[i,0] = boxes[maxpos,0] + boxes[i,1] = boxes[maxpos,1] + boxes[i,2] = boxes[maxpos,2] + boxes[i,3] = boxes[maxpos,3] + boxes[i,4] = boxes[maxpos,4] + + mx1 = boxes[i, 0] * boxes[i, 5] + my1 = boxes[i, 1] * boxes[i, 5] + mx2 = boxes[i, 2] * boxes[i, 6] + my2 = boxes[i, 3] * boxes[i, 6] + mts = boxes[i, 5] + mbs = boxes[i, 6] + + # swap ith box with position of max box + boxes[maxpos,0] = tx1 + boxes[maxpos,1] = ty1 + boxes[maxpos,2] = tx2 + boxes[maxpos,3] = ty2 + boxes[maxpos,4] = ts + + tx1 = boxes[i,0] + ty1 = boxes[i,1] + tx2 = boxes[i,2] + ty2 = boxes[i,3] + ts = boxes[i,4] + + pos = i + 1 + # NMS iterations, note that N changes if detection boxes fall below threshold + while pos < N: + x1 = boxes[pos, 0] + y1 = boxes[pos, 1] + x2 = boxes[pos, 2] + y2 = boxes[pos, 3] + s = boxes[pos, 4] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + iw = (min(tx2, x2) - max(tx1, x1) + 1) + if iw > 0: + ih = (min(ty2, y2) - max(ty1, y1) + 1) + if ih > 0: + ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) + ov = iw * ih / ua #iou between max box and detection box + + if method == 1: # linear + if ov > Nt: + weight = 1 - ov + else: + weight = 1 + elif method == 2: # gaussian + weight = np.exp(-(ov * ov)/sigma) + else: # original NMS + if ov > Nt: + weight = 0 + else: + weight = 1 + + mw = (1 - weight) ** weight_exp + mx1 = mx1 + boxes[pos, 0] * boxes[pos, 5] * mw + my1 = my1 + boxes[pos, 1] * boxes[pos, 5] * mw + mx2 = mx2 + boxes[pos, 2] * boxes[pos, 6] * mw + my2 = my2 + boxes[pos, 3] * boxes[pos, 6] * mw + mts 
= mts + boxes[pos, 5] * mw + mbs = mbs + boxes[pos, 6] * mw + + boxes[pos, 4] = weight*boxes[pos, 4] + + # if box score falls below threshold, discard the box by swapping with last box + # update N + if boxes[pos, 4] < threshold: + boxes[pos,0] = boxes[N-1, 0] + boxes[pos,1] = boxes[N-1, 1] + boxes[pos,2] = boxes[N-1, 2] + boxes[pos,3] = boxes[N-1, 3] + boxes[pos,4] = boxes[N-1, 4] + N = N - 1 + pos = pos - 1 + + pos = pos + 1 + + boxes[i, 0] = mx1 / mts + boxes[i, 1] = my1 / mts + boxes[i, 2] = mx2 / mbs + boxes[i, 3] = my2 / mbs + + keep = [i for i in range(N)] + return keep diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/setup.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/setup.py new file mode 100644 index 0000000000..aa41a6b616 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/external/setup.py @@ -0,0 +1,42 @@ +""" +MIT License + +Copyright (c) 2019 Xingyi Zhou +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +from distutils.core import setup +from distutils.extension import Extension +import numpy +from Cython.Build import cythonize + +extensions = [ + Extension( + "nms", + ["nms.pyx"], + extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] + ) +] + +setup( + name="coco", + ext_modules=cythonize(extensions), + include_dirs=[numpy.get_include()] +) diff --git a/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/utils/image.py b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/utils/image.py new file mode 100644 index 0000000000..ec960b0216 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/centernet/src/lib/utils/image.py @@ -0,0 +1,170 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Xingyi Zhou +# ------------------------------------------------------------------------------ +"""Image process""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import numpy as np +import cv2 + +def get_3rd_point(a, b): + """ + Get 3rd point + """ + direct = a - b + return b + np.array([-direct[1], direct[0]], dtype=np.float32) + +def get_dir(src_point, rot_rad): + """ + Get dir + """ + sn, cs = np.sin(rot_rad), np.cos(rot_rad) # (0, 1) + + src_result = [0, 0] + src_result[0] = src_point[0] * cs - src_point[1] * sn + src_result[1] = src_point[0] * sn + src_point[1] * cs + + return src_result + +def get_affine_transform(center, + scale, + rot, + output_size, + shift=np.array([0, 0], dtype=np.float32), + inv=0): + """ + Get affine transform + """ + if not isinstance(scale, np.ndarray) and not isinstance(scale, list): + scale = np.array([scale, scale], dtype=np.float32) + + scale_tmp = scale + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = get_dir([0, src_w * -0.5], rot_rad) + dst_dir = np.array([0, dst_w * -0.5], np.float32) + + src = np.zeros((3, 2), dtype=np.float32) + dst = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5], np.float32) + dst_dir + + src[2:, :] = get_3rd_point(src[0, :], src[1, :]) + dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + +def affine_transform(pt, t): + """ + Affine transform + """ + new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32).T + new_pt = np.dot(t, new_pt) + return new_pt[:2] + +def grayscale(image): + return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + +def lighting_(data_rng, image, alphastd, eigval, eigvec): + alpha = data_rng.normal(scale=alphastd, size=(3,)) + image += np.dot(eigvec, eigval * alpha) + +def blend_(alpha, image1, image2): + image1 *= alpha + image2 *= (1 - alpha) + image1 += image2 + +def saturation_(data_rng, image, gs, gs_mean, var): + gs_mean = gs_mean + alpha = 1. + data_rng.uniform(low=-var, high=var) + blend_(alpha, image, gs[:, :, None]) + +def brightness_(data_rng, image, gs, gs_mean, var): + gs = gs + gs_mean = gs_mean + alpha = 1. + data_rng.uniform(low=-var, high=var) + image *= alpha + +def contrast_(data_rng, image, gs, gs_mean, var): + gs = gs + alpha = 1. 
+ data_rng.uniform(low=-var, high=var) + blend_(alpha, image, gs_mean) + +def color_aug(data_rng, image, eig_val, eig_vec): + functions = [brightness_, contrast_, saturation_] + random.shuffle(functions) + + gs = grayscale(image) + gs_mean = gs.mean() + for f in functions: + f(data_rng, image, gs, gs_mean, 0.4) + lighting_(data_rng, image, 0.1, eig_val, eig_vec) + +def gaussian_radius(det_size, min_overlap=0.7): + """ + Gaussian radius + """ + height, width = det_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + return min(r1, r2, r3) + +def gaussian2d(shape, sigma=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m+1, -n:n+1] + + h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h + +def draw_umich_gaussian(heatmap, center, radius, k=1): + """ + Draw umich gaussian + """ + diameter = 2 * radius + 1 + gaussian = gaussian2d((diameter, diameter), sigma=diameter / 6) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + return heatmap diff --git a/model_zoo/official/cv/centerface/denpendency/evaluate/box_overlaps.pyx b/model_zoo/official/cv/centerface/denpendency/evaluate/box_overlaps.pyx new file mode 100644 index 0000000000..ad326ba1df --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/evaluate/box_overlaps.pyx @@ -0,0 +1,55 @@ +# -------------------------------------------------------- +# Fast R-CNN +# Copyright (c) 2015 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Sergey Karayev +# -------------------------------------------------------- + +cimport cython +import numpy as np +cimport numpy as np + +DTYPE = np.float +ctypedef np.float_t DTYPE_t + +def bbox_overlaps( + np.ndarray[DTYPE_t, ndim=2] boxes, + np.ndarray[DTYPE_t, ndim=2] query_boxes): + """ + Parameters + ---------- + boxes: (N, 4) ndarray of float + query_boxes: (K, 4) ndarray of float + Returns + ------- + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + cdef unsigned int N = boxes.shape[0] + cdef unsigned int K = query_boxes.shape[0] + cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) + cdef DTYPE_t iw, ih, box_area + cdef DTYPE_t ua + cdef unsigned int k, n + for k in range(K): + box_area = ( + (query_boxes[k, 2] - query_boxes[k, 0] + 1) * + (query_boxes[k, 3] - query_boxes[k, 1] + 1) + ) + for n in range(N): + iw = ( + min(boxes[n, 2], query_boxes[k, 2]) - + max(boxes[n, 0], query_boxes[k, 0]) + 1 + ) + if iw > 0: + ih = ( + min(boxes[n, 3], query_boxes[k, 3]) - + max(boxes[n, 1], query_boxes[k, 1]) + 1 + ) + if ih > 0: + ua = float( + (boxes[n, 2] - boxes[n, 0] + 1) * + (boxes[n, 3] - 
boxes[n, 1] + 1) + + box_area - iw * ih + ) + overlaps[n, k] = iw * ih / ua + return overlaps \ No newline at end of file diff --git a/model_zoo/official/cv/centerface/denpendency/evaluate/eval.py b/model_zoo/official/cv/centerface/denpendency/evaluate/eval.py new file mode 100644 index 0000000000..7c4f260daf --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/evaluate/eval.py @@ -0,0 +1,316 @@ +""" +WiderFace evaluation code +author: wondervictor +mail: tianhengcheng@gmail.com +copyright@wondervictor + +MIT License + +Copyright (c) 2018 Vic Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import division + +import os +import pickle +import argparse +import numpy as np +from scipy.io import loadmat +from bbox import bbox_overlaps + +def get_gt_boxes(gt_dir): + """ gt dir: (wider_face_val.mat, wider_easy_val.mat, wider_medium_val.mat, wider_hard_val.mat)""" + + gt_mat = loadmat(os.path.join(gt_dir, 'wider_face_val.mat')) # you own ground_truth name + hard_mat = loadmat(os.path.join(gt_dir, 'wider_hard_val.mat')) + medium_mat = loadmat(os.path.join(gt_dir, 'wider_medium_val.mat')) + easy_mat = loadmat(os.path.join(gt_dir, 'wider_easy_val.mat')) + + facebox_list = gt_mat['face_bbx_list'] + event_list = gt_mat['event_list'] + file_list = gt_mat['file_list'] + + hard_gt_list = hard_mat['gt_list'] + medium_gt_list = medium_mat['gt_list'] + easy_gt_list = easy_mat['gt_list'] + + return facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list + + +def get_gt_boxes_from_txt(gt_path, cache_dir): + """ + Get gt boxes from binary txt file. 
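+    The txt file alternates an event/image name line (containing '--'), a face
+    count line, and one box per line; parsed boxes are cached to
+    'gt_cache.pkl' under cache_dir and reused on later calls.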
+ """ + cache_file = os.path.join(cache_dir, 'gt_cache.pkl') + if os.path.exists(cache_file): + f = open(cache_file, 'rb') + boxes = pickle.load(f) + f.close() + return boxes + + f = open(gt_path, 'r') + state = 0 + lines = f.readlines() + lines = list(map(lambda x: x.rstrip('\r\n'), lines)) + boxes = {} + f.close() + current_boxes = [] + current_name = None + for line in lines: + if state == 0 and '--' in line: + state = 1 + current_name = line + continue + if state == 1: + state = 2 + continue + + if state == 2 and '--' in line: + state = 1 + boxes[current_name] = np.array(current_boxes).astype('float32') + current_name = line + current_boxes = [] + continue + + if state == 2: + box = [float(x) for x in line.split(' ')[:4]] + current_boxes.append(box) + continue + + f = open(cache_file, 'wb') + pickle.dump(boxes, f) + f.close() + return boxes + + +def read_pred_file(filepath): + + with open(filepath, 'r') as f: + lines = f.readlines() + img_file = lines[0].rstrip('\n\r') + lines = lines[2:] + + boxes = np.array(list(map(lambda x: [float(a) for a in x.rstrip('\r\n').split(' ')], lines))).astype('float') + return img_file.split('/')[-1], boxes + + +def get_preds(pred_dir): + """Get preds""" + events = os.listdir(pred_dir) + boxes = dict() + #pbar = tqdm.tqdm(events) + pbar = events + for event in pbar: + #pbar.set_description('Reading Predictions ') + event_dir = os.path.join(pred_dir, event) + event_images = os.listdir(event_dir) + current_event = dict() + for imgtxt in event_images: + imgname, box = read_pred_file(os.path.join(event_dir, imgtxt)) + current_event[imgname.rstrip('.jpg')] = box + boxes[event] = current_event + return boxes + + +def norm_score(pred_norm): + """ norm score + pred_norm {key: [[x1,y1,x2,y2,s]]} + """ + max_score = 0 + min_score = 1 + + for _, k in pred_norm.items(): + for _, v in k.items(): + if v.size == 0: + continue + min_v = np.min(v[:, -1]) + max_v = np.max(v[:, -1]) + max_score = max(max_v, max_score) + min_score = min(min_v, min_score) + + diff = max_score - min_score + for _, k in pred_norm.items(): + for _, v in k.items(): + if v.size == 0: + continue + v[:, -1] = (v[:, -1] - min_score)/diff + + +def image_eval(pred_eval, gt, ignore, iou_thresh): + """ single image evaluation + pred_eval: Nx5 + gt: Nx4 + ignore: + """ + pred_t = pred_eval.copy() + gt_t = gt.copy() + pred_recall = np.zeros(pred_t.shape[0]) + recall_list = np.zeros(gt_t.shape[0]) + proposal_list = np.ones(pred_t.shape[0]) + + pred_t[:, 2] = pred_t[:, 2] + pred_t[:, 0] + pred_t[:, 3] = pred_t[:, 3] + pred_t[:, 1] + gt_t[:, 2] = gt_t[:, 2] + gt_t[:, 0] + gt_t[:, 3] = gt_t[:, 3] + gt_t[:, 1] + + overlaps = bbox_overlaps(pred_t[:, :4], gt_t) + + for h in range(pred_t.shape[0]): + + gt_overlap = overlaps[h] + max_overlap, max_idx = gt_overlap.max(), gt_overlap.argmax() + if max_overlap >= iou_thresh: + if ignore[max_idx] == 0: + recall_list[max_idx] = -1 + proposal_list[h] = -1 + elif recall_list[max_idx] == 0: + recall_list[max_idx] = 1 + + r_keep_index = np.where(recall_list == 1)[0] + pred_recall[h] = len(r_keep_index) + return pred_recall, proposal_list + + +def img_pr_info(thresh_num, pred_info, proposal_list, pred_recall): + """ + Image pr info + """ + pr_info = np.zeros((thresh_num, 2)).astype('float') + for t in range(thresh_num): + + thresh = 1 - (t+1)/thresh_num + r_index = np.where(pred_info[:, 4] >= thresh)[0] + if r_index.size == 0: + pr_info[t, 0] = 0 + pr_info[t, 1] = 0 + else: + r_index = r_index[-1] + p_index = np.where(proposal_list[:r_index+1] == 1)[0] + pr_info[t, 0] = 
len(p_index) + pr_info[t, 1] = pred_recall[r_index] + return pr_info + + +def dataset_pr_info(thresh_num, pr_curve, count_face): + pr_curve_t = np.zeros((thresh_num, 2)) + for i in range(thresh_num): + pr_curve_t[i, 0] = pr_curve[i, 1] / pr_curve[i, 0] + pr_curve_t[i, 1] = pr_curve[i, 1] / count_face + return pr_curve_t + + +def voc_ap(rec, prec): + """ + Voc ap calculation + """ + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def evaluation(pred_evaluation, gt_path, iou_thresh=0.4): + """ + evaluation method. + """ + print_pred = pred_evaluation + pred_evaluation = get_preds(pred_evaluation) + norm_score(pred_evaluation) + facebox_list, event_list, file_list, hard_gt_list, medium_gt_list, easy_gt_list = get_gt_boxes(gt_path) + event_num = len(event_list) + thresh_num = 1000 + setting_gts = [easy_gt_list, medium_gt_list, hard_gt_list] + + aps = [] + for setting_id in range(3): + # different setting + gt_list = setting_gts[setting_id] + count_face = 0 + pr_curve = np.zeros((thresh_num, 2)).astype('float') + # [hard, medium, easy] + # pbar = tqdm.tqdm(range(event_num)) # 61 + pbar = range(event_num) + error_count = 0 + for i in pbar: + event_name = str(event_list[i][0][0]) + img_list = file_list[i][0] + pred_list = pred_evaluation[event_name] + sub_gt_list = gt_list[i][0] + gt_bbx_list = facebox_list[i][0] + + for j, _ in enumerate(img_list): + try: + pred_info = pred_list[str(img_list[j][0][0])] + except KeyError: + error_count += 1 + continue + + gt_boxes = gt_bbx_list[j][0].astype('float') + keep_index = sub_gt_list[j][0] + count_face += len(keep_index) + if gt_boxes.size == 0 or pred_info.size == 0: + continue + ignore = np.zeros(gt_boxes.shape[0]) + if keep_index.size != 0: + ignore[keep_index-1] = 1 + pred_recall, proposal_list = image_eval(pred_info, gt_boxes, ignore, iou_thresh) + + pr_curve += img_pr_info(thresh_num, pred_info, proposal_list, pred_recall) + + pr_curve = dataset_pr_info(thresh_num, pr_curve, count_face) + + propose = pr_curve[:, 0] + recall = pr_curve[:, 1] + + ap = voc_ap(recall, propose) + aps.append(ap) + + print("==================== Results = ====================", print_pred) + print("Easy Val AP: {}".format(aps[0])) + print("Medium Val AP: {}".format(aps[1])) + print("Hard Val AP: {}".format(aps[2])) + print("=================================================") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--pred', default='', + help='test output, txt contain box positions and scores') + parser.add_argument('-g', '--gt', default='', help='ground truth path, mat format') + args = parser.parse_args() + + pred = args.pred + if os.path.isdir(pred): + evaluation(pred, args.gt) + else: + pass diff --git a/model_zoo/official/cv/centerface/denpendency/evaluate/setup.py b/model_zoo/official/cv/centerface/denpendency/evaluate/setup.py new file mode 100644 index 0000000000..5c5eb85bde --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/evaluate/setup.py @@ -0,0 +1,13 @@ +""" +WiderFace evaluation code +author: wondervictor +mail: 
tianhengcheng@gmail.com +copyright@wondervictor +""" + +from distutils.core import setup, Extension +import numpy +from Cython.Build import cythonize + +package = Extension('bbox', ['box_overlaps.pyx'], include_dirs=[numpy.get_include()]) +setup(ext_modules=cythonize([package])) diff --git a/model_zoo/official/cv/centerface/denpendency/extd/utils/augmentations.py b/model_zoo/official/cv/centerface/denpendency/extd/utils/augmentations.py new file mode 100644 index 0000000000..b79fc37e70 --- /dev/null +++ b/model_zoo/official/cv/centerface/denpendency/extd/utils/augmentations.py @@ -0,0 +1,78 @@ +#EXTD: Extremely Tiny Face Detector via Iterative Filter Reuse +# MIT license + +# Copyright (c) 2019-present NAVER Corp. + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE +"""Augmentations""" + +import random +import numpy as np +import cv2 + +def anchor_crop_image_sampling(image, anns): + """ + Crop anchors. 
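+    Randomly pick one annotated face, match its side length to the closest
+    entry in a fixed list of anchor sizes, then rescale the whole image (and
+    its bbox/keypoint annotations) by a jittered ratio towards an anchor
+    drawn from the scales no larger than the matched one, keeping the
+    resized area below max_size * max_size.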
+ """ + max_size = 12000 + inf_distance = 9999999 + + boxes = [] + for ann in anns: + boxes.append([ann['bbox'][0], ann['bbox'][1], ann['bbox'][0] + ann['bbox'][2], ann['bbox'][1] + ann['bbox'][3]]) + boxes = np.asarray(boxes, dtype=np.float32) + + height, width, _ = image.shape + + box_area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1) + rand_idx = random.randint(0, len(box_area) - 1) + rand_side = box_area[rand_idx] ** 0.5 + + anchors = [16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 128, 256, 512] + distance = inf_distance + anchor_idx = 5 + for i, anchor in enumerate(anchors): + if abs(anchor - rand_side) < distance: + distance = abs(anchor - rand_side) + anchor_idx = i + + target_anchor = random.choice(anchors[0:min(anchor_idx + 1, 11)]) + ratio = float(target_anchor) / rand_side + ratio = ratio * (2 ** random.uniform(-1, 1)) + + if int(height * ratio * width * ratio) > max_size * max_size: + ratio = (max_size * max_size / (height * width)) ** 0.5 + + interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] + interp_method = random.choice(interp_methods) + image = cv2.resize(image, None, None, fx=ratio, fy=ratio, interpolation=interp_method) + + boxes[:, 0] *= ratio + boxes[:, 1] *= ratio + boxes[:, 2] *= ratio + boxes[:, 3] *= ratio + + boxes = boxes.tolist() + for i, _ in enumerate(anns): + anns[i]['bbox'] = [boxes[i][0], boxes[i][1], boxes[i][2] - boxes[i][0], boxes[i][3] - boxes[i][1]] + for j in range(5): + anns[i]['keypoints'][j * 3] *= ratio + anns[i]['keypoints'][j * 3 + 1] *= ratio + + return image, anns diff --git a/model_zoo/official/cv/centerface/export.py b/model_zoo/official/cv/centerface/export.py new file mode 100644 index 0000000000..a7095d2292 --- /dev/null +++ b/model_zoo/official/cv/centerface/export.py @@ -0,0 +1,62 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Convert ckpt to air.""" +import os +import argparse +import numpy as np + +from mindspore import context +from mindspore import Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.centerface import CenterfaceMobilev2 + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + +def save_air(): + """Save air file""" + print('============= centerface start save air ==================') + + parser = argparse.ArgumentParser(description='Convert ckpt to air') + parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') + parser.add_argument('--batch_size', type=int, default=8, help='batch size') + + args = parser.parse_args() + network = CenterfaceMobilev2() + + if os.path.isfile(args.pretrained): + param_dict = load_checkpoint(args.pretrained) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values + load_param_into_net(network, param_dict_new) + print('load model {} success'.format(args.pretrained)) + + input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32) + + tensor_input_data = Tensor(input_data) + export(network, tensor_input_data, + file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR') + + print("export model success.") + + +if __name__ == "__main__": + save_air() diff --git a/model_zoo/official/cv/centerface/scripts/eval.sh b/model_zoo/official/cv/centerface/scripts/eval.sh new file mode 100644 index 0000000000..671b61883b --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/eval.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +root=$PWD +save_path=$root/output/centerface/ +ground_truth_path=$root/dataset/centerface/ground_truth +echo "start eval" +python ../dependency/evaluate/eval.py --pred=$save_path --gt=$ground_truth_path +echo "end eval" diff --git a/model_zoo/official/cv/centerface/scripts/eval_all.sh b/model_zoo/official/cv/centerface/scripts/eval_all.sh new file mode 100644 index 0000000000..993016621b --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/eval_all.sh @@ -0,0 +1,26 @@ +#!/bin/sh +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +root=$PWD +save_path=$root/output/centerface/ +ground_truth_path=$root/dataset/centerface/ground_truth +#for i in $(seq start_epoch end_epoch+1) +for i in $(seq 89 200) +do + python ../dependency/evaluate/eval.py --pred=$save_path$i --gt=$ground_truth_path & + sleep 10 +done +wait diff --git a/model_zoo/official/cv/centerface/scripts/test.sh b/model_zoo/official/cv/centerface/scripts/test.sh new file mode 100644 index 0000000000..c3470c71f4 --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/test.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -gt 6 ] +then + echo "Usage: sh test.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID] [CKPT]" + echo " or: sh test.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID]" + echo " or: sh test.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH]" + echo " or: sh test.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT]" + echo " or: sh test.sh [MODEL_PATH] [DATASET]" + echo " or: sh test.sh [MODEL_PATH]" + echo " or: sh test.sh " +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +current_exec_path=$(pwd) +echo ${current_exec_path} + +dirname_path=$(dirname "$(pwd)") +echo ${dirname_path} + +SCRIPT_NAME='test.py' + +ulimit -c unlimited + +root=${current_exec_path} # your script path +model_path=$root/model/ +dataset_root=$root/dataset +dataset_path=$dataset_root/centerface/images/val/images/ +ground_truth_mat=$dataset_root/centerface/ground_truth/val.mat +save_path=$root/output/centerface/ +device_id=0 +ckpt="0.ckpt" # the model saved for epoch=125 + +if [ $# == 1 ] +then + model_path=$(get_real_path $1) + if [ ! -f $model_path ] + then + echo "error: model_path=$model_path is not a file" + exit 1 + fi +fi + +if [ $# == 2 ] +then + dataset_path=$(get_real_path $2) + if [ ! -f $dataset_path ] + then + echo "error: dataset_path=$dataset_path is not a file" + exit 1 + fi +fi + +if [ $# == 3 ] +then + ground_truth_mat=$(get_real_path $3) + if [ ! -f $ground_truth_mat ] + then + echo "error: ground_truth_mat=$ground_truth_mat is not a file" + exit 1 + fi +fi + +if [ $# == 4 ] +then + save_path=$(get_real_path $4) + if [ ! 
-f $save_path ] + then + echo "error: save_path=$save_path is not a file" + exit 1 + fi +fi + +if [ $# == 5 ] +then + device_id=$5 +fi + +if [ $# == 6 ] +then + ckpt=$6 +fi + +echo $model_path +echo $dataset_path +echo $ground_truth_mat +echo $save_path + +export PYTHONPATH=${dirname_path}:$PYTHONPATH +export RANK_SIZE=1 + +echo 'start testing' +rm -rf ${current_exec_path}/device_test$device_id +echo 'start rank '$device_id +mkdir ${current_exec_path}/device_test$device_id +cd ${current_exec_path}/device_test$device_id || exit +export RANK_ID=0 +dev=`expr $device_id + 0` +export DEVICE_ID=$dev +python ${dirname_path}/${SCRIPT_NAME} \ + --is_distributed=0 \ + --data_dir=$dataset_path \ + --test_model=$model_path \ + --ground_truth_mat=$ground_truth_mat \ + --save_dir=$save_path \ + --rank=$device_id \ + --ckpt_name=$ckpt > test.log 2>&1 & + +echo 'running' diff --git a/model_zoo/official/cv/centerface/scripts/test_and_eval.sh b/model_zoo/official/cv/centerface/scripts/test_and_eval.sh new file mode 100644 index 0000000000..973415adc0 --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/test_and_eval.sh @@ -0,0 +1,146 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -gt 6 ] +then + echo "Usage: sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID] [CKPT] [GROUND_TRUTH_PATH]" + echo " or: sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID] [CKPT]" + echo " or: sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_ID]" + echo " or: sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH]" + echo " or: sh test_and_eval.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT]" + echo " or: sh test_and_eval.sh [MODEL_PATH] [DATASET]" + echo " or: sh test_and_eval.sh [MODEL_PATH]" + echo " or: sh test_and_eval.sh " +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +current_exec_path=$(pwd) +echo ${current_exec_path} + +dirname_path=$(dirname "$(pwd)") +echo ${dirname_path} + +SCRIPT_NAME='test.py' + +ulimit -c unlimited + +root=${current_exec_path} # your script path +model_path=$root/model/ +dataset_root=$root/dataset +dataset_path=$dataset_root/centerface/images/val/images/ +ground_truth_mat=$dataset_root/centerface/ground_truth/val.mat +save_path=$root/output/centerface/999 +device_id=0 +ckpt="0-125_24750.ckpt" # the model saved for epoch=125 +ground_truth_path=$root/dataset/centerface/ground_truth + +if [ $# == 1 ] +then + model_path=$(get_real_path $1) + if [ ! -f $model_path ] + then + echo "error: model_path=$model_path is not a file" + exit 1 + fi +fi + +if [ $# == 2 ] +then + dataset_path=$(get_real_path $2) + if [ ! 
-f $dataset_path ] + then + echo "error: dataset_path=$dataset_path is not a file" + exit 1 + fi +fi + +if [ $# == 3 ] +then + ground_truth_mat=$(get_real_path $3) + if [ ! -f $ground_truth_mat ] + then + echo "error: ground_truth_mat=$ground_truth_mat is not a file" + exit 1 + fi +fi + +if [ $# == 4 ] +then + save_path=$(get_real_path $4) + if [ ! -f $save_path ] + then + echo "error: save_path=$save_path is not a file" + exit 1 + fi +fi + +if [ $# == 5 ] +then + device_id=$5 +fi + +if [ $# == 6 ] +then + ckpt=$6 +fi + +if [ $# == 7 ] +then + ground_truth_path=$(get_real_path $7) + if [ ! -f $ground_truth_path ] + then + echo "error: ground_truth_path=$ground_truth_path is not a file" + exit 1 + fi +fi + +echo $model_path +echo $dataset_path +echo $ground_truth_mat +echo $save_path +echo $ground_truth_path + +export PYTHONPATH=${dirname_path}:$PYTHONPATH +export RANK_SIZE=1 + +echo 'start testing' +rm -rf ${current_exec_path}/device_test$device_id +echo 'start rank '$device_id +mkdir ${current_exec_path}/device_test$device_id +cd ${current_exec_path}/device_test$device_id || exit +export RANK_ID=0 +dev=`expr $device_id + 0` +export DEVICE_ID=$dev +python ${dirname_path}/${SCRIPT_NAME} \ + --is_distributed=0 \ + --data_dir=$dataset_path \ + --test_model=$model_path \ + --ground_truth_mat=$ground_truth_mat \ + --save_dir=$save_path \ + --rank=$device_id \ + --ckpt_name=$ckpt \ + --eval=1 \ + --ground_truth_path=$ground_truth_path > test.log 2>&1 & + +echo 'running' diff --git a/model_zoo/official/cv/centerface/scripts/test_distribute.sh b/model_zoo/official/cv/centerface/scripts/test_distribute.sh new file mode 100644 index 0000000000..ef51c49f7a --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/test_distribute.sh @@ -0,0 +1,157 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# -gt 8 ] +then + echo "Usage: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM] [STEPS_PER_EPOCH] [START] [END]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM] [STEPS_PER_EPOCH] [START]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM] [STEPS_PER_EPOCH]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH] [DEVICE_NUM]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT] [SAVE_PATH]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET] [GROUND_TRUTH_MAT]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET]" + echo " or: sh test_distribute.sh [MODEL_PATH] [DATASET]" + echo " or: sh test_distribute.sh [MODEL_PATH]" + echo " or: sh test_distribute.sh " +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +current_exec_path=$(pwd) +echo ${current_exec_path} + +dirname_path=$(dirname "$(pwd)") +echo ${dirname_path} + +SCRIPT_NAME='test.py' + +ulimit -c unlimited + +root=${current_exec_path} # your script path +model_path=$root/model/ +dataset_root=$root/dataset +dataset_path=$dataset_root/centerface/images/val/images/ +ground_truth_mat=$dataset_root/centerface/ground_truth/val.mat +save_path=$root/output/centerface/ +# blow are used for calculate model name +# model/ckpt name is "0-" + str(ckpt_num) + "_" + str(198*ckpt_num) + ".ckpt"; +# ckpt_num is epoch number, can be calculated by device_num +# detail can be found in "test.py" +device_num=8 +steps_per_epoch=198 #198 for 8P; 1583 for 1p +start=11 # start epoch number = start * device_num + min(device_phy_id) + 1 +end=18 # end epoch number = end * device_num + max(device_phy_id) + 1 + +if [ $# == 1 ] +then + model_path=$(get_real_path $1) + if [ ! -f $model_path ] + then + echo "error: model_path=$model_path is not a file" + exit 1 + fi +fi + +if [ $# == 2 ] +then + dataset_path=$(get_real_path $2) + if [ ! -f $dataset_path ] + then + echo "error: dataset_path=$dataset_path is not a file" + exit 1 + fi +fi + +if [ $# == 3 ] +then + ground_truth_mat=$(get_real_path $3) + if [ ! -f $ground_truth_mat ] + then + echo "error: ground_truth_mat=$ground_truth_mat is not a file" + exit 1 + fi +fi + +if [ $# == 4 ] +then + save_path=$(get_real_path $4) + if [ ! 
-f $save_path ] + then + echo "error: save_path=$save_path is not a file" + exit 1 + fi +fi + +if [ $# == 5 ] +then + device_num=$5 +fi + +if [ $# == 6 ] +then + steps_per_epoch=$6 +fi + +if [ $# == 7 ] +then + start=$7 +fi + +if [ $# == 8 ] +then + end=$8 +fi + +echo $model_path +echo $dataset_path +echo $ground_truth_mat +echo $save_path + +export PYTHONPATH=${dirname_path}:$PYTHONPATH +export RANK_SIZE=1 + +echo 'start testing' +rm -rf ${current_exec_path}/device_test* +for((i=0;i<=$device_num-1;i++)); +do + echo 'start rank '$i + mkdir ${current_exec_path}/device_test$i + cd ${current_exec_path}/device_test$i || exit + export RANK_ID=0 + dev=`expr $i + 0` + export DEVICE_ID=$dev + python ${dirname_path}/${SCRIPT_NAME} \ + --is_distributed=0 \ + --data_dir=$dataset_path \ + --test_model=$model_path \ + --ground_truth_mat=$ground_truth_mat \ + --save_dir=$save_path \ + --rank=$i \ + --device_num=$device_num \ + --steps_per_epoch=$steps_per_epoch \ + --start=$start \ + --end=$end > test.log 2>&1 & +done + +echo 'running' diff --git a/model_zoo/official/cv/centerface/scripts/train_distribute.sh b/model_zoo/official/cv/centerface/scripts/train_distribute.sh new file mode 100644 index 0000000000..8382089a3b --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/train_distribute.sh @@ -0,0 +1,141 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 0 ] && [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ] +then + echo "Usage: sh train_distribute.sh [RANK_TABLE] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS] [IMAGES]" + echo " or: sh train_distribute.sh [RANK_TABLE] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS]" + echo " or: sh train_distribute.sh [RANK_TABLE] [PRETRAINED_BACKBONE] [DATASET]" + echo " or: sh train_distribute.sh [RANK_TABLE] [PRETRAINED_BACKBONE]" + echo " or: sh train_distribute.sh [RANK_TABLE]" + echo " or: sh train_distribute.sh " +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +current_exec_path=$(pwd) +echo ${current_exec_path} + +dirname_path=$(dirname "$(pwd)") +echo ${dirname_path} + +rm -rf ${current_exec_path}/device* +SCRIPT_NAME='train.py' + +ulimit -c unlimited + +root=${current_exec_path} # your script path +pretrained_backbone=${dirname_path}/mobilenet_v2.ckpt # or mobilenet_v2-b0353104.ckpt +dataset_path=$root/dataset/centerface +annot_path=$dataset_path/annotations/train.json +img_dir=$dataset_path/images/train/images +rank_table=$root/rank_table_8p.json + +if [ $# == 1 ] +then + rank_table=$(get_real_path $1) + if [ ! -f $rank_table ] + then + echo "error: rank_table=$rank_table is not a file" + exit 1 + fi +fi + +if [ $# == 2 ] +then + pretrained_backbone=$(get_real_path $2) + if [ ! 
-f $pretrained_backbone ] + then + echo "error: pretrained_backbone=$pretrained_backbone is not a file" + exit 1 + fi +fi + +if [ $# == 3 ] +then + dataset_path=$(get_real_path $3) + if [ ! -f $dataset_path ] + then + echo "error: dataset_path=$dataset_path is not a file" + exit 1 + fi +fi + +if [ $# == 4 ] +then + annot_path=$(get_real_path $4) + if [ ! -f $annot_path ] + then + echo "error: annot_path=$annot_path is not a file" + exit 1 + fi +fi + +if [ $# == 5 ] +then + img_dir=$(get_real_path $5) + if [ ! -f $img_dir ] + then + echo "error: img_dir=$img_dir is not a file" + exit 1 + fi +fi + +echo $rank_table +echo $pretrained_backbone +echo $dataset_path +echo $annot_path +echo $img_dir + +export PYTHONPATH=${dirname_path}:$PYTHONPATH +export RANK_TABLE_FILE=$rank_table +export RANK_SIZE=8 + +task_set_core=24 # for taskset, task_set_core=total cpu number/RANK_SIZE +echo 'start training' +for((i=0;i<=$RANK_SIZE-1;i++)); +do + echo 'start rank '$i + mkdir ${current_exec_path}/device$i + cd ${current_exec_path}/device$i || exit + export RANK_ID=$i + dev=`expr $i + 0` + export DEVICE_ID=$dev + taskset -c $((i*task_set_core))-$(((i+1)*task_set_core-1)) python ${dirname_path}/${SCRIPT_NAME} \ + --lr=4e-3 \ + --per_batch_size=8 \ + --is_distributed=1 \ + --t_max=140 \ + --max_epoch=140 \ + --warmup_epochs=0 \ + --lr_scheduler=multistep \ + --lr_epochs=90,120 \ + --weight_decay=0.0000 \ + --loss_scale=1024 \ + --pretrained_backbone=$pretrained_backbone \ + --data_dir=$dataset_path \ + --annot_path=$annot_path \ + --img_dir=$img_dir > train.log 2>&1 & +done + +echo 'running' diff --git a/model_zoo/official/cv/centerface/scripts/train_standalone.sh b/model_zoo/official/cv/centerface/scripts/train_standalone.sh new file mode 100644 index 0000000000..31446a8a71 --- /dev/null +++ b/model_zoo/official/cv/centerface/scripts/train_standalone.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +if [ $# != 0 ] && [ $# != 1 ] && [ $# != 2 ] && [ $# != 3 ] && [ $# != 4 ] && [ $# != 5 ] +then + echo "Usage: sh train_standalone.sh [USE_DEVICE_ID] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS] [IMAGES]" + echo " or: sh train_standalone.sh [USE_DEVICE_ID] [PRETRAINED_BACKBONE] [DATASET] [ANNOTATIONS]" + echo " or: sh train_standalone.sh [USE_DEVICE_ID] [PRETRAINED_BACKBONE] [DATASET]" + echo " or: sh train_standalone.sh [USE_DEVICE_ID] [PRETRAINED_BACKBONE]" + echo " or: sh train_standalone.sh [USE_DEVICE_ID]" + echo " or: sh train_standalone.sh " +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +current_exec_path=$(pwd) +echo ${current_exec_path} + +dirname_path=$(dirname "$(pwd)") +echo ${dirname_path} + +SCRIPT_NAME='train.py' + +ulimit -c unlimited + +root=${current_exec_path} # your script path +pretrained_backbone=${dirname_path}/mobilenet_v2.ckpt # or mobilenet_v2-b0353104.ckpt +dataset_path=$root/dataset/centerface +annot_path=$dataset_path/annotations/train.json +img_dir=$dataset_path/images/train/images +use_device_id=0 + +if [ $# == 1 ] +then + use_device_id=$1 +fi + +if [ $# == 2 ] +then + pretrained_backbone=$(get_real_path $2) + if [ ! -f $pretrained_backbone ] + then + echo "error: pretrained_backbone=$pretrained_backbone is not a file" + exit 1 + fi +fi + +if [ $# == 3 ] +then + dataset_path=$(get_real_path $3) + if [ ! -f $dataset_path ] + then + echo "error: dataset_path=$dataset_path is not a file" + exit 1 + fi +fi + +if [ $# == 4 ] +then + annot_path=$(get_real_path $4) + if [ ! -f $annot_path ] + then + echo "error: annot_path=$annot_path is not a file" + exit 1 + fi +fi + +if [ $# == 5 ] +then + img_dir=$(get_real_path $5) + if [ ! -f $img_dir ] + then + echo "error: img_dir=$img_dir is not a file" + exit 1 + fi +fi + +echo $use_device_id +echo $pretrained_backbone +echo $dataset_path +echo $annot_path +echo $img_dir + +export PYTHONPATH=${dirname_path}:$PYTHONPATH +export RANK_SIZE=1 + +echo 'start training' +echo 'start rank '$use_device_id +rm -rf ${current_exec_path}/device$use_device_id +mkdir ${current_exec_path}/device$use_device_id +cd ${current_exec_path}/device$use_device_id || exit +export RANK_ID=0 +dev=`expr $use_device_id + 0` +export DEVICE_ID=$dev +python ${dirname_path}/${SCRIPT_NAME} \ + --lr=5e-4 \ + --per_batch_size=8 \ + --is_distributed=0 \ + --t_max=140 \ + --max_epoch=140 \ + --warmup_epochs=0 \ + --lr_scheduler=multistep \ + --lr_epochs=90,120 \ + --weight_decay=0.0000 \ + --loss_scale=1024 \ + --pretrained_backbone=$pretrained_backbone \ + --data_dir=$dataset_path \ + --annot_path=$annot_path \ + --img_dir=$img_dir > train.log 2>&1 & + +echo 'running' diff --git a/model_zoo/official/cv/centerface/src/__init__.py b/model_zoo/official/cv/centerface/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/centerface/src/centerface.py b/model_zoo/official/cv/centerface/src/centerface.py new file mode 100644 index 0000000000..e5742cd08f --- /dev/null +++ b/model_zoo/official/cv/centerface/src/centerface.py @@ -0,0 +1,324 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""centerface networks""" + +from src.config import ConfigCenterface +from src.mobile_v2 import mobilenet_v2 +from src.losses import FocalLoss, SmoothL1LossNew, SmoothL1LossNewCMask + +import mindspore as ms +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore import context +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_group_size +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore.common import dtype as mstype +from mindspore.ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual +from mindspore.context import ParallelMode + +_grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + +@_grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * reciprocal(scale) + +def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False): + return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias, + padding=padding, pad_mode="pad") + + +def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False): + return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias, + padding=padding, pad_mode="pad") + + +def convTranspose2x2(in_channels, out_channels, has_bias=False): # Davinci devices only support 'groups=1' + return nn.Conv2dTranspose(in_channels, out_channels, kernel_size=2, stride=2, has_bias=has_bias, + weight_init='normal', bias_init='zeros') + + +class IDAUp(nn.Cell): + """ + IDA Module. + """ + def __init__(self, out_dim, channel): + super(IDAUp, self).__init__() + self.out_dim = out_dim + self.up = nn.SequentialCell([ + convTranspose2x2(out_dim, out_dim, has_bias=False), + nn.BatchNorm2d(out_dim, eps=0.001, momentum=0.9).add_flags_recursive(fp32=True), + nn.ReLU()]) + self.conv = nn.SequentialCell([ + conv1x1(channel, out_dim), + nn.BatchNorm2d(out_dim, eps=0.001, momentum=0.9).add_flags_recursive(fp32=True), + nn.ReLU()]) + + def construct(self, x0, x1): + x = self.up(x0) + y = self.conv(x1) + out = x + y + return out + + +class MobileNetUp(nn.Cell): + """ + Mobilenet module. 
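+    Feature-fusion neck on top of the MobileNetV2 backbone: a 1x1 conv reduces
+    the deepest feature map to out_dim channels, three IDAUp blocks upsample and
+    fuse it with the shallower feature maps, and a final 3x3 conv produces the
+    shared feature consumed by the detection heads.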
+ """ + def __init__(self, channels, out_dim=24): + super(MobileNetUp, self).__init__() + channels = channels[::-1] + self.conv = nn.SequentialCell([ + conv1x1(channels[0], out_dim), + nn.BatchNorm2d(out_dim, eps=0.001).add_flags_recursive(fp32=True), + nn.ReLU()]) + self.conv_last = nn.SequentialCell([ + conv3x3(out_dim, out_dim), + nn.BatchNorm2d(out_dim, eps=1e-5, momentum=0.99).add_flags_recursive(fp32=True), + nn.ReLU()]) + + self.up1 = IDAUp(out_dim, channels[1]) + self.up2 = IDAUp(out_dim, channels[2]) + self.up3 = IDAUp(out_dim, channels[3]) + + def construct(self, x1, x2, x3, x4): # tuple/list can be type of input of a subnet + x = self.conv(x4) # top_layer, change outdim + + x = self.up1(x, x3) + x = self.up2(x, x2) + x = self.up3(x, x1) + x = self.conv_last(x) + return x + +class Cast(nn.Cell): + def __init__(self): + super(Cast, self).__init__() + self.cast = P.Cast() + + def construct(self, x): + return self.cast(x, ms.float32) + +class CenterfaceMobilev2(nn.Cell): + """ + Mobilev2 based CenterFace network. + + Args: + num_classes: Integer. Class number. + feature_shape: List. Input image shape, [N,C,H,W]. + + Returns: + Cell, cell instance of Darknet based YOLOV3 neural network. + CenterFace use the same structure. + + Examples: + yolov3_darknet53(80, [1,3,416,416]) + + """ + + def __init__(self): + super(CenterfaceMobilev2, self).__init__() + self.config = ConfigCenterface() + + self.base = mobilenet_v2() + channels = self.base.feat_channel + self.dla_up = MobileNetUp(channels, out_dim=self.config.head_conv) + + self.hm_head = nn.SequentialCell([conv1x1(self.config.head_conv, 1, has_bias=True), + nn.Sigmoid().add_flags_recursive(fp32=True)]) + self.wh_head = conv1x1(self.config.head_conv, 2, has_bias=True) + self.off_head = conv1x1(self.config.head_conv, 2, has_bias=True) + self.kps_head = conv1x1(self.config.head_conv, 10, has_bias=True) + + def construct(self, x): + x1, x2, x3, x4 = self.base(x) + x = self.dla_up(x1, x2, x3, x4) + + output_hm = self.hm_head(x) + output_wh = self.wh_head(x) + output_off = self.off_head(x) + output_kps = self.kps_head(x) + return output_hm, output_wh, output_off, output_kps + +class CenterFaceLoss(nn.Cell): + """ + Loss method defination. + """ + def __init__(self, wh_weight, reg_offset, off_weight, hm_weight, lm_weight): + super(CenterFaceLoss, self).__init__() + # --- config parameter + self.wh_weight = wh_weight + self.reg_offset = reg_offset + self.off_weight = off_weight + self.hm_weight = hm_weight + self.lm_weight = lm_weight + # --- + self.cls_loss = FocalLoss() + self.reg_loss = SmoothL1LossNew() + self.reg_loss_cmask = SmoothL1LossNewCMask() + self.print = P.Print() + # self.reduce_sum = P.ReduceSum() + + def construct(self, output_hm, output_wh, output_off, output_kps, hm, reg_mask, ind, wh, wight_mask, hm_offset, + hps_mask, landmarks): + """ + Construct method. + """ + hm_loss = self.cls_loss(output_hm, hm) # 1. focal loss, center points + wh_loss = self.reg_loss(output_wh, ind, wh, wight_mask) # 2. weight and height + off_loss = self.reg_loss(output_off, ind, hm_offset, wight_mask) # 3. offset + lm_loss = self.reg_loss_cmask(output_kps, hps_mask, ind, landmarks) # 4. 
landmark loss + + loss = self.hm_weight * hm_loss + self.wh_weight * wh_loss + \ + self.off_weight * off_loss + self.lm_weight * lm_loss + + # depend is needed when wight_mask and reg_mask is not been used + F.depend(loss, F.sqrt(F.cast(wight_mask, mstype.float32))) + F.depend(loss, F.sqrt(F.cast(reg_mask, mstype.float32))) + # add print when you want to see loss detail and do debug + #self.print('hm_loss=', hm_loss, 'wh_loss=', wh_loss, 'off_loss=', off_loss, 'lm_loss=', lm_loss, 'loss=', loss) + return loss + + +class CenterFaceWithLossCell(nn.Cell): + """ + Centerface with loss cell. + """ + def __init__(self, network): + super(CenterFaceWithLossCell, self).__init__() + self.centerface_network = network + self.config = ConfigCenterface() + self.loss = CenterFaceLoss(self.config.wh_weight, self.config.reg_offset, self.config.off_weight, + self.config.hm_weight, self.config.lm_weight) + self.reduce_sum = P.ReduceSum() + self.print = P.Print() + + def construct(self, x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks): + output_hm, output_wh, output_off, output_kps = self.centerface_network(x) + loss = self.loss(output_hm, output_wh, output_off, output_kps, hm, reg_mask, ind, wh, wight_mask, hm_offset, + hps_mask, landmarks) + return loss + +class TrainingWrapper(nn.Cell): + """ + Training wrapper + """ + def __init__(self, network, optimizer, sens=1.0): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() #False + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + if self.reducer_flag: + mean = context.get_auto_parallel_context("gradients_mean") + if auto_parallel_context().get_device_num_is_set(): + degree = context.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + self.hyper_map = C.HyperMap() + self.alloc_status = NPUAllocFloatStatus() + self.get_status = NPUGetFloatStatus() + self.clear_status = NPUClearFloatStatus() + self.reduce_sum = ReduceSum(keep_dims=False) + self.base = Tensor(1, mstype.float32) + self.less_equal = LessEqual() + self.allreduce = P.AllReduce() + self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE + + # x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks + def construct(self, x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks): + """ + Construct method. 
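+        Runs the forward pass, builds a gradient sensitivity tensor from
+        self.sens, computes (and, under data parallel, all-reduces) the
+        gradients, then reads the NPU float status registers to detect
+        overflow; returns (loss, overflow_flag, sens) after the optimizer update.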
+ """ + weights = self.weights + loss = self.network(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks) + + # init overflow buffer + init = self.alloc_status() + # clear overflow buffer + self.clear_status(init) + + #sens = sens_input #P.Fill()(P.DType()(loss), P.Shape()(loss), sens_input) # user can contral loss scale by add a sens_input + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(x, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks, + sens) + #grads = self.hyper_map(F.partial(_grad_scale, sens), grads) # if add this, the loss_scale optimizer is needed to set to 1 + if self.reducer_flag: + grads = self.grad_reducer(grads) + + # get the overflow buffer + self.get_status(init) + # sum overflow buffer elements, 0:not overflow , >0:overflow + flag_sum = self.reduce_sum(init, (0,)) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + + ret = (loss, cond, sens) + return F.depend(ret, self.optimizer(grads)) + + +class CenterFaceWithNms(nn.Cell): + """ + CenterFace with nms. + """ + def __init__(self, network): + super(CenterFaceWithNms, self).__init__() + self.centerface_network = network + self.config = ConfigCenterface() + # two type of maxpool self.maxpool2d = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='same') + self.maxpool2d = P.MaxPoolWithArgmax(ksize=3, strides=1, padding='same') + self.topk = P.TopK(sorted=True) + self.reshape = P.Reshape() + self.print = P.Print() + self.test_batch = self.config.test_batch_size + self.k = self.config.K + + def construct(self, x): + """ + Construct method. + """ + output_hm, output_wh, output_off, output_kps = self.centerface_network(x) + output_hm_nms, _ = self.maxpool2d(output_hm) + abs_error = P.Abs()(output_hm - output_hm_nms) + abs_out = P.Abs()(output_hm) + error = abs_error / (abs_out + 1e-12) + + # cannot use P.Equal()(output_hm, output_hm_nms), since maxpooling output has 0.1% error + keep = P.Select()(P.LessEqual()(error, 1e-3), \ + P.Fill()(ms.float32, P.Shape()(error), 1.0), \ + P.Fill()(ms.float32, P.Shape()(error), 0.0)) + output_hm = output_hm * keep + + # get topK and index + scores = self.reshape(output_hm, (self.test_batch, -1)) + topk_scores, topk_inds = self.topk(scores, self.k) + return topk_scores, output_wh, output_off, output_kps, topk_inds diff --git a/model_zoo/official/cv/centerface/src/config.py b/model_zoo/official/cv/centerface/src/config.py new file mode 100644 index 0000000000..ce4dad5ed5 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/config.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""centerface unique configs""" + +class ConfigCenterface(): + """ + Config setup + """ + flip_idx = [[0, 1], [3, 4]] + default_resolution = [512, 512] + heads = {'hm': 1, 'wh': 2, 'hm_offset': 2, 'landmarks': 5 * 2} + head_conv = 64 + max_objs = 64 + + rand_crop = True + scale = 0.4 + shift = 0.1 + aug_rot = 0 + color_aug = True + flip = 0.5 + input_res = 512 #768 #800 + output_res = 128 #192 #200 + num_classes = 1 + num_joints = 5 + reg_offset = True + hm_hp = True + reg_hp_offset = True + dense_hp = False + hm_weight = 1.0 + wh_weight = 0.1 + off_weight = 1.0 + lm_weight = 0.1 + rotate = 0 + + # for test + mean = [0.408, 0.447, 0.470] + std = [0.289, 0.274, 0.278] + test_scales = [0.999,] + nms = 1 + flip_test = 0 + fix_res = True + input_h = 832 #800 + input_w = 832 #800 + K = 200 + down_ratio = 4 + test_batch_size = 1 + + seed = 317 + master_batch_size = 8 + num_workers = 8 + not_rand_crop = False + no_color_aug = False diff --git a/model_zoo/official/cv/centerface/src/convert_weight.py b/model_zoo/official/cv/centerface/src/convert_weight.py new file mode 100644 index 0000000000..c3dcc8c872 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/convert_weight.py @@ -0,0 +1,181 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Centerface model transform +""" +import os +import argparse +import torch +from mindspore.train.serialization import load_checkpoint, save_checkpoint +from mindspore import Tensor + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--ckpt_fn', type=str, default='/model_path/centerface.ckpt', + help='ckpt for user to get cell/module name') +parser.add_argument('--pt_fn', type=str, default='/model_path/centerface.pth', help='checkpoint filename to convert') +parser.add_argument('--out_fn', type=str, default='/model_path/centerface_out.ckpt', + help='convert output ckpt/pth path') +parser.add_argument('--pt2ckpt', type=int, default=1, help='1 : pt2ckpt; 0 : ckpt2pt') + +args = parser.parse_args() + +def load_model(model_path): + """ + Load model + """ + checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) + print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch'])) + state_dict_ = checkpoint['state_dict'] + state_dict = {} + + # convert data_parallal to model + for k in state_dict_: + if k.find("num_batches_tracked") != -1: + continue + elif k.startswith('module') and not k.startswith('module_list'): + state_dict[k[7:]] = state_dict_[k] + else: + state_dict[k] = state_dict_[k] + + return state_dict + +def save_model(path, epoch=0, model=None, optimizer=None, state_dict=None): + """ + Sace model file + """ + if state_dict is None: + if isinstance(model, torch.nn.DataParallel): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + data = {'epoch': epoch, + 'state_dict': state_dict} + if not optimizer is None: + data['optimizer'] = optimizer.state_dict() + torch.save(data, path) + +def load_model_ms(model_path): + """ + Load mindspore model + """ + state_dict_useless = ['global_step', 'learning_rate', + 'beta1_power', 'beta2_power'] + if os.path.isfile(model_path): + param_dict = load_checkpoint(model_path) + param_dict_new = {} + for key, values in param_dict.items(): + if key in state_dict_useless or key.startswith('moments.') \ + or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values + else: + assert FileNotFoundError('{} not exists or not a pre-trained file'.format(model_path)) + exit(1) + return param_dict_new + +def name_map(ckpt): + """ + Name map + """ + out = {} + for name in ckpt: + # conv + bn + pt_name = name + # backbone + pt_name = pt_name.replace('need_fp1', 'feature_1') + pt_name = pt_name.replace('need_fp2', 'feature_2') + pt_name = pt_name.replace('need_fp3', 'feature_4') + pt_name = pt_name.replace('need_fp4', 'feature_6') + pt_name = pt_name.replace('.features', '') + pt_name = pt_name.replace('.moving_mean', '.running_mean') + pt_name = pt_name.replace('.moving_variance', '.running_var') + pt_name = pt_name.replace('.gamma', '.weight') + pt_name = pt_name.replace('.beta', '.bias') + # fpn + pt_name = pt_name.replace('.up1', '.up_0') + pt_name = pt_name.replace('.up2', '.up_1') + pt_name = pt_name.replace('.up3', '.up_2') + # heads + pt_name = pt_name.replace('hm_head.0.', 'hm.') + pt_name = pt_name.replace('wh_head.', 'wh.') + pt_name = pt_name.replace('off_head.', 'hm_offset.') + pt_name = pt_name.replace('kps_head.', 'landmarks.') + + out[pt_name] = name + return out + +def pt_to_ckpt(pt, ckpt, out_path): + """ + Pt convert to ckpt file + """ + state_dict_torch = 
load_model(pt) + state_dict_ms = load_model_ms(ckpt) + name_relate = name_map(state_dict_ms) + + new_params_list = [] + for key in state_dict_torch: + param_dict = {} + parameter = state_dict_torch[key] + parameter = parameter.numpy() + + # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout + if state_dict_ms[name_relate[key]].data.shape != parameter.shape: + parameter = parameter.transpose(1, 0, 2, 3) + print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) + + param_dict['name'] = name_relate[key] + param_dict['data'] = Tensor(parameter) + new_params_list.append(param_dict) + + save_checkpoint(new_params_list, out_path) + return state_dict_ms + +def ckpt_to_pt(pt, ckpt, out_path): + """ + Ckpt convert to pt file + """ + state_dict_torch = load_model(pt) + state_dict_ms = load_model_ms(ckpt) + name_relate = name_map(state_dict_ms) + + state_dict = {} + for key in state_dict_torch: + name = name_relate[key] + parameter = state_dict_ms[name].data + parameter = parameter.asnumpy() + if state_dict_ms[name_relate[key]].data.shape != state_dict_torch[key].numpy().shape: + print('before ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', + state_dict_torch[key].numpy().shape, 'name=', key) + parameter = parameter.transpose(1, 0, 2, 3) + print('after ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', + state_dict_torch[key].numpy().shape, 'name=', key) + + state_dict[key] = torch.from_numpy(parameter) + + save_model(out_path, epoch=0, model=None, optimizer=None, state_dict=state_dict) + + return state_dict + +if __name__ == "__main__": + if args.pt2ckpt == 1: + pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_fn) + elif args.pt2ckpt == 0: + ckpt_to_pt(args.pt_fn, args.ckpt_fn, args.out_fn) + else: + # user defined functions + pass diff --git a/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py new file mode 100644 index 0000000000..e4ef3bf427 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/convert_weight_mobilenetv2.py @@ -0,0 +1,138 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Mobilenet model transform: torch => mindspore +""" +import os +import argparse +import torch +from mindspore.train.serialization import load_checkpoint, save_checkpoint +from mindspore import Tensor + +parser = argparse.ArgumentParser(description='') +parser.add_argument('--ckpt_fn', type=str, default='/model_path/mobilenet_v2_key.ckpt', + help='ckpt for user to get cell/module name') +parser.add_argument('--pt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.pth', + help='checkpoint filename to convert') +parser.add_argument('--out_ckpt_fn', type=str, default='/model_path/mobilenet_v2-b0353104.ckpt', + help='convert output ckpt path') + +args = parser.parse_args() + +def load_model(model_path): + """ + Load model + """ + state_dict_ = torch.load(model_path, map_location=torch.device('cpu')) + state_dict = {} + + # convert data_parallal to model + for k in state_dict_: + if k.find("num_batches_tracked") != -1: + continue + elif k.startswith('module') and not k.startswith('module_list'): + state_dict[k[7:]] = state_dict_[k] + else: + state_dict[k] = state_dict_[k] + return state_dict + +def load_model_ms(model_path): + """ + Load mindspore model + """ + state_dict_useless = ['global_step', 'learning_rate', + 'beta1_power', 'beta2_power'] + if os.path.isfile(model_path): + param_dict = load_checkpoint(model_path) + param_dict_new = {} + for key, values in param_dict.items(): + if key in state_dict_useless or key.startswith('moments.') \ + or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): #useless, since the start name is "network.backbone." + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values + else: + assert FileNotFoundError('{} not exists or not a pre-trained file'.format(model_path)) + exit(1) + return param_dict_new + +def name_map(ckpt): + """ + Name map + """ + out = {} + for name in ckpt: + # conv + bn + pt_name = name + + pt_name = pt_name.replace('network.backbone.', '') + # backbone + pt_name = pt_name.replace('need_fp1', 'feature_1') + pt_name = pt_name.replace('need_fp2', 'feature_2') + pt_name = pt_name.replace('need_fp3', 'feature_4') + pt_name = pt_name.replace('need_fp4', 'feature_6') + pt_name = pt_name.replace('.features', '') + pt_name = pt_name.replace('.moving_mean', '.running_mean') + pt_name = pt_name.replace('.moving_variance', '.running_var') + pt_name = pt_name.replace('.gamma', '.weight') + pt_name = pt_name.replace('.beta', '.bias') + # fpn + pt_name = pt_name.replace('.up1', '.up_0') + pt_name = pt_name.replace('.up2', '.up_1') + pt_name = pt_name.replace('.up3', '.up_2') + + # heads + pt_name = pt_name.replace('hm_head.0.', 'hm.') + pt_name = pt_name.replace('wh_head.', 'wh.') + pt_name = pt_name.replace('off_head.', 'hm_offset.') + pt_name = pt_name.replace('kps_head.', 'landmarks.') + + pt_name = pt_name.replace('network.head.fc.', 'classifier.1.') + + out[pt_name] = name + return out + +def pt_to_ckpt(pt, ckpt, out_ckpt): + """ + Pt convert to ckpt file + """ + state_dict_torch = load_model(pt) + state_dict_ms = load_model_ms(ckpt) + name_relate = name_map(state_dict_ms) + new_params_list = [] + + for key in state_dict_torch: + param_dict = {} + parameter = state_dict_torch[key] + parameter = parameter.numpy() + + # depwise conv pytorch[cout, 1, k , k] -> ms[1, cin, k , k], cin = cout + if state_dict_ms[name_relate[key]].data.shape != parameter.shape: + parameter = 
parameter.transpose(1, 0, 2, 3) + print('ms=', state_dict_ms[name_relate[key]].data.shape, 'pytorch=', parameter.shape, 'name=', key) + + + param_dict['name'] = name_relate[key] + param_dict['data'] = Tensor(parameter) + new_params_list.append(param_dict) + + save_checkpoint(new_params_list, out_ckpt) + return state_dict_ms + +if __name__ == "__main__": + # beta <=> bias, gamma <=> weight + pt_to_ckpt(args.pt_fn, args.ckpt_fn, args.out_ckpt_fn) diff --git a/model_zoo/official/cv/centerface/src/dataset.py b/model_zoo/official/cv/centerface/src/dataset.py new file mode 100644 index 0000000000..9ba0f4619d --- /dev/null +++ b/model_zoo/official/cv/centerface/src/dataset.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""generate dataloader and data processing entry""" + +import mindspore.dataset.engine as de + +from src.utils import DistributedSampler + +from dependency.centernet.src.lib.datasets.dataset.coco_hp import CenterfaceDataset +from dependency.centernet.src.lib.datasets.sample.multi_pose import preprocess_train + +def GetDataLoader(per_batch_size, + max_epoch, + rank, + group_size, + config, + split='train'): + """ + Centerface get data loader + """ + centerface_gen = CenterfaceDataset(config=config, split=split) + sampler = DistributedSampler(centerface_gen, rank, group_size, shuffle=(split == 'train')) # user defined sampling strategy + de_dataset = de.GeneratorDataset(centerface_gen, ["image", "anns"], sampler=sampler, num_parallel_workers=16) + + if group_size > 1: + num_parallel_workers = 24 + else: + num_parallel_workers = 64 + if split == 'train': + compose_map_func = (lambda image, anns: preprocess_train(image, anns, config=config)) + columns = ['image', "hm", 'reg_mask', 'ind', 'wh', 'wight_mask', 'hm_offset', 'hps_mask', 'landmarks'] + de_dataset = de_dataset.map(input_columns=["image", "anns"], + output_columns=columns, + column_order=columns, + operations=compose_map_func, + num_parallel_workers=num_parallel_workers, + python_multiprocessing=True) + + de_dataset = de_dataset.batch(per_batch_size, drop_remainder=True, num_parallel_workers=8) + if split == 'train': + #de_dataset = de_dataset.repeat(1) # if use this, need an additional "for" cycle epoch times + de_dataset = de_dataset.repeat(max_epoch) + + return de_dataset, de_dataset.get_dataset_size() diff --git a/model_zoo/official/cv/centerface/src/losses.py b/model_zoo/official/cv/centerface/src/losses.py new file mode 100644 index 0000000000..dbde6870c4 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/losses.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""losses for centerface""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype + +# focal loss: afa=2, beta=4 +class FocalLoss(nn.Cell): + '''nn.Cell warpper for focal loss''' + def __init__(self): + super(FocalLoss, self).__init__() + self.log = P.Log() + self.pow = P.Pow() + self.sum = P.ReduceSum() + self.print = P.Print() + + def construct(self, pred, gt): + """Construct method""" + pos_inds = P.Select()(P.Equal()(gt, 1.0), P.Fill()(P.DType()(gt), P.Shape()(gt), 1.0), P.Fill()(P.DType()(gt), + P.Shape()(gt), + 0.0)) + neg_inds = P.Select()(P.Less()(gt, 1.0), P.Fill()(P.DType()(gt), P.Shape()(gt), 1.0), P.Fill()(P.DType()(gt), + P.Shape()(gt), + 0.0)) + + neg_weights = self.pow(1 - gt, 4) # beta=4 + # afa=2 + pos_loss = self.log(pred) * self.pow(1 - pred, 2) * pos_inds + neg_loss = self.log(1 - pred) * self.pow(pred, 2) * neg_weights * neg_inds + + num_pos = self.sum(pos_inds, ()) + num_pos = P.Select()(P.Equal()(num_pos, 0.0), P.Fill()(P.DType()(num_pos), P.Shape()(num_pos), 1.0), num_pos) + + pos_loss = self.sum(pos_loss, ()) + neg_loss = self.sum(neg_loss, ()) + loss = - (pos_loss + neg_loss) / num_pos + return loss + +class SmoothL1LossNew(nn.Cell): + """Smoothl1loss""" + def __init__(self): + super(SmoothL1LossNew, self).__init__() + self.transpose = P.Transpose() + self.smooth_l1_loss = nn.SmoothL1Loss() + self.shape = P.Shape() + self.expand_dims = P.ExpandDims() + self.sum = P.ReduceSum() + self.cast = P.Cast() + + def construct(self, output, ind, target, wight_mask=None): + ''' + :param output: [b, c, h, w] to [b, h, w, c] + :param ind: + :param target: + :return: + ''' + output = self.transpose(output, (0, 2, 3, 1)) + # dim = self.shape(output)[3] + mask = P.Select()(P.Equal()(ind, 1), P.Fill()(mstype.float32, P.Shape()(ind), 1.0), P.Fill()(mstype.float32, + P.Shape()(ind), + 0.0)) + # ind = self.cast(ind, mstype.float32) + target = self.cast(target, mstype.float32) + output = self.cast(output, mstype.float32) + num = self.cast(self.sum(mask, ()), mstype.float32) + mask = self.expand_dims(mask, -1) # [batch,h,w]--[batch,h,w,c] + output = output * mask + target = target * mask + loss = self.smooth_l1_loss(output, target) + if wight_mask is not None: + loss = loss * wight_mask + loss = self.sum(loss, ()) + else: + #some version need: F.depend(loss, F.sqrt(F.cast(wight_mask, mstype.float32))) + loss = self.sum(loss, ()) + loss = loss / (num + 1e-4) + return loss + +class SmoothL1LossNewCMask(nn.Cell): + """Smoothl1loss with mask""" + def __init__(self): + super(SmoothL1LossNewCMask, self).__init__() + self.transpose = P.Transpose() + self.smooth_l1_loss = nn.L1Loss(reduction='sum') # or use nn.SmoothL1Loss() + self.shape = P.Shape() + self.expand_dims = P.ExpandDims() + self.sum = P.ReduceSum() + self.cast = P.Cast() + + def construct(self, output, cmask, ind, target): + ''' + :param output: [b, c, h, w] to [b, h, w, c] + :param ind: + :param target: + :return: + ''' + num = self.sum(cmask, ()) + output = 
self.transpose(output, (0, 2, 3, 1)) + + ind = self.cast(ind, mstype.float32) + target = self.cast(target, mstype.float32) + cmask = self.cast(cmask, mstype.float32) + output = self.cast(output, mstype.float32) + ind = self.expand_dims(ind, -1) + output = output * ind + target = target * ind + loss = self.smooth_l1_loss(output*cmask, target*cmask) + #loss = self.sum(loss, ()) # if use SmoothL1Loss, this is needed + loss = loss / (num + 1e-4) + return loss diff --git a/model_zoo/official/cv/centerface/src/lr_scheduler.py b/model_zoo/official/cv/centerface/src/lr_scheduler.py new file mode 100644 index 0000000000..588207afa3 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/lr_scheduler.py @@ -0,0 +1,851 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate scheduler""" + +import math +from collections import Counter +import numpy as np + +__all__ = ["LambdaLR", "MultiplicativeLR", "StepLR", "MultiStepLR", "ExponentialLR", "CosineAnnealingLR", "CyclicLR", + "CosineAnnealingWarmRestarts", "OneCycleLR", "POLYLR"] + +class _WarmUp(): + """ + Basic class for warm up + """ + def __init__(self, warmup_init_lr): + self.warmup_init_lr = warmup_init_lr + + def get_lr(self, current_step=0): + # Get learning rate during warmup + current_step = 0 + raise NotImplementedError + +class _LinearWarmUp(_WarmUp): + """ + Class for linear warm up + """ + def __init__(self, lr, warmup_epochs, steps_per_epoch, warmup_init_lr=0): + self.base_lr = lr + self.warmup_init_lr = warmup_init_lr + self.warmup_steps = int(warmup_epochs * steps_per_epoch) + + super(_LinearWarmUp, self).__init__(warmup_init_lr) + + def get_warmup_steps(self): + return self.warmup_steps + + def get_lr(self, current_step=0): + lr_inc = (float(self.base_lr) - float(self.warmup_init_lr)) / float(self.warmup_steps) + lr = float(self.warmup_init_lr) + lr_inc * current_step + return lr + +class _ConstWarmUp(_WarmUp): + """ + Class for const warm up + """ + def __init__(self, warmup_init_lr): + super(_ConstWarmUp, self).__init__(warmup_init_lr) + self.warmup_init_lr = warmup_init_lr + + def get_lr(self, current_step=0): + current_step = 0 + return self.warmup_init_lr + +class _LRScheduler(): + """ + Basic class for learning rate scheduler + """ + def __init__(self, lr, max_epoch, steps_per_epoch): + self.base_lr = lr + self.steps_per_epoch = steps_per_epoch + self.total_steps = int(max_epoch * steps_per_epoch) + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + +class LambdaLR(_LRScheduler): + r""" + Lambda learning rate scheduler + + Sets the learning rate to the initial lr times a given function. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + lr_lambda (func. 
or list): A function which computes a multiplicative factor given an integer parameter epoch. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> lambda1 = lambda epoch: epoch // 30 + >>> scheduler = LambdaLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0): + self.lr_lambda = lr_lambda + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(LambdaLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + cur_ep = i // self.steps_per_epoch + lr = self.base_lr * self.lr_lambda(cur_ep) + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class MultiplicativeLR(_LRScheduler): + """ + Multiplicative learning rate scheduler + + Multiply the learning rate by the factor given in the specified function. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + lr_lambda (func. or list): A function which computes a multiplicative factor given an integer parameter epoch. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(lr=0.1, lr_lambda=lambda1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + def __init__(self, lr, lr_lambda, steps_per_epoch, max_epoch, warmup_epochs=0): + self.lr_lambda = lr_lambda + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(MultiplicativeLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + current_lr = self.base_lr + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + cur_ep = i // self.steps_per_epoch + if i % self.steps_per_epoch == 0 and cur_ep > 0: + current_lr = current_lr * self.lr_lambda(cur_ep) + + lr = current_lr + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class StepLR(_LRScheduler): + """ + Step learning rate scheduler + + Decays the learning rate by gamma every epoch_size epochs. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + epoch_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... 
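+        >>> # With lr=0.1 as in the call below, the same rule gives:
+        >>> # lr = 0.1 if epoch < 30, 0.01 if 30 <= epoch < 60, 0.001 if 60 <= epoch < 90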
+ >>> scheduler = StepLR(lr=0.1, epoch_size=30, gamma=0.1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, epoch_size, gamma, steps_per_epoch, max_epoch, warmup_epochs=0): + self.epoch_size = epoch_size + self.gamma = gamma + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(StepLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + cur_ep = i // self.steps_per_epoch + lr = self.base_lr * self.gamma**(cur_ep // self.epoch_size) + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class POLYLR(_LRScheduler): + """ + Poly learning rate scheduler + """ + def __init__(self, lr, steps_per_epoch, max_epoch, end_lr, power): + self.end_lr = end_lr + self.power = power + self.max_epoch = max_epoch + self.lr = lr + self.end_lr = end_lr + super(POLYLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + lr_each_step = [] + total_steps = self.steps_per_epoch * self.max_epoch + for i in range(total_steps): + step_ = min(i, total_steps) + lr_each_step.append((self.lr - self.end_lr) * ((1.0 - step_ / total_steps) ** self.power) + self.end_lr) + print("lr_each_step:", lr_each_step[-1]) + return np.array(lr_each_step).astype(np.float32) + +class MultiStepLR(_LRScheduler): + """ + Multi-step learning rate scheduler + + Decays the learning rate by gamma once the number of epoch reaches one of the milestones. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(lr=0.1, milestones=[30,80], gamma=0.1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, milestones, gamma, steps_per_epoch, max_epoch, warmup_epochs=0): + self.milestones = Counter(milestones) + self.gamma = gamma + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(MultiStepLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + current_lr = self.base_lr + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + cur_ep = i // self.steps_per_epoch + if i % self.steps_per_epoch == 0 and cur_ep in self.milestones: + current_lr = current_lr * self.gamma + lr = current_lr + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class ExponentialLR(_LRScheduler): + """ + Exponential learning rate scheduler + + Decays the learning rate of each parameter group by gamma every epoch. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + gamma (float): Multiplicative factor of learning rate decay. + steps_per_epoch (int): The number of steps per epoch to train for. 
+ max_epoch (int): The number of epochs to train for. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> scheduler = ExponentialLR(lr=0.1, gamma=0.1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, gamma, steps_per_epoch, max_epoch, warmup_epochs=0): + self.gamma = gamma + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(ExponentialLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + current_lr = self.base_lr + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + if i % self.steps_per_epoch == 0 and i > 0: + current_lr = current_lr * self.gamma + lr = current_lr + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class CosineAnnealingLR(_LRScheduler): + r""" + Cosine annealing scheduler + + Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` + is set to the initial lr and :math:`t_{cur}` is the number of epochs since the + last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{t_{cur}}{t_{max}}\pi\right)\right), + & t_{cur} \neq (2k+1)t_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{t_{max}}\pi\right)\right), + & t_{cur} = (2k+1)t_{max}. + \end{aligned} + + It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Note: + This only implements the cosine annealing part of SGDR, and not the restarts. + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + t_max (int): Maximum number of iterations. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + eta_min (float, optional): Minimum learning rate. Default: 0. + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> scheduler = CosineAnnealingLR(lr=0.1, t_max=120, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, t_max, steps_per_epoch, max_epoch, warmup_epochs=0, eta_min=0): + self.t_max = t_max + self.eta_min = eta_min + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(CosineAnnealingLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + current_lr = self.base_lr + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + cur_ep = i // self.steps_per_epoch + if i % self.steps_per_epoch == 0 and i > 0: + current_lr = self.eta_min + (self.base_lr - self.eta_min) * \ + (1. + math.cos(math.pi*cur_ep / self.t_max)) / 2 + + lr = current_lr + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class CyclicLR(_LRScheduler): + r""" + Cyclical learning rate scheduler + + Sets the learning rate according to cyclical learning rate policy (CLR). 
+ The policy cycles the learning rate between two boundaries with a constant + frequency, as detailed in the paper `Cyclical Learning Rates for Training + Neural Networks`_. The distance between the two boundaries can be scaled on + a per-iteration or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + max_lr (float): Upper learning rate boundaries in the cycle. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. + Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. + Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training iterations since start of cycle). + Default: 'cycle' + warmup_epochs (int, optional): The number of epochs to Warmup. 
Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> scheduler = CyclicLR(lr=0.1, max_lr=1.0, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, + lr, + max_lr, + steps_per_epoch, + max_epoch, + step_size_up=2000, + step_size_down=None, + mode='triangular', + gamma=1., + scale_fn=None, + scale_mode='cycle', + warmup_epochs=0): + + self.max_lr = max_lr + + step_size_up = float(step_size_up) + step_size_down = float(step_size_down) if step_size_down is not None else step_size_up + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ['triangular', 'triangular2', 'exp_range'] \ + and scale_fn is None: + raise ValueError('mode is invalid and scale_fn is None') + + self.mode = mode + self.gamma = gamma + + if scale_fn is None: + if self.mode == 'triangular': + self.scale_fn = self._triangular_scale_fn + self.scale_mode = 'cycle' + elif self.mode == 'triangular2': + self.scale_fn = self._triangular2_scale_fn + self.scale_mode = 'cycle' + elif self.mode == 'exp_range': + self.scale_fn = self._exp_range_scale_fn + self.scale_mode = 'iterations' + else: + self.scale_fn = scale_fn + self.scale_mode = scale_mode + + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(CyclicLR, self).__init__(lr, max_epoch, steps_per_epoch) + + def _triangular_scale_fn(self): + return 1. + + def _triangular2_scale_fn(self, x): + return 1 / (2. ** (x - 1)) + + def _exp_range_scale_fn(self, x): + return self.gamma**(x) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + # Calculates the learning rate at batch index. + cycle = math.floor(1 + i / self.total_size) + x = 1. + i / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + base_height = (self.max_lr - self.base_lr) * scale_factor + if self.scale_mode == 'cycle': + lr = self.base_lr + base_height * self.scale_fn(cycle) + elif self.mode == 'triangular': + lr = self.base_lr + base_height * self.scale_fn() + else: + lr = self.base_lr + base_height * self.scale_fn(i) + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class CosineAnnealingWarmRestarts(_LRScheduler): + r""" + Cosine annealing scheduler with warm restarts + + Set the learning rate using a cosine annealing schedule, where + :math:`\eta_{max}` is set to the initial lr, :math:`t_{cur}` is the + number of epochs since the last restart and :math:`t_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{t_{cur}}{t_{i}}\pi\right)\right) + + When :math:`t_{cur}=t_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`t_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + + Args: + + warmup_epochs (int): The number of epochs to Warmup. + Default: 0 + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. 
+ t_0 (int): Number of iterations for the first restart. + t_mult (int, optional): A factor increases :math:`t_{i}` after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + warmup_epochs (int, optional): The number of epochs to Warmup. Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> scheduler = CosineAnnealingWarmRestarts(lr=0.1, steps_per_epoch=5000, max_epoch=90, t_0=2) + >>> lr = scheduler.get_lr() + """ + + def __init__(self, lr, steps_per_epoch, max_epoch, t_0, t_mult=1, eta_min=0, warmup_epochs=0): + if t_0 <= 0 or not isinstance(t_0, int): + raise ValueError("Expected positive integer t_0, but got {}".format(t_0)) + if t_mult < 1 or not isinstance(t_mult, int): + raise ValueError("Expected integer t_mult >= 1, but got {}".format(t_mult)) + self.t_0 = t_0 + self.t_i = t_0 + self.t_mult = t_mult + self.eta_min = eta_min + self.t_cur = 0 + + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(CosineAnnealingWarmRestarts, self).__init__(lr, max_epoch, steps_per_epoch) + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + if i % self.steps_per_epoch == 0 and i > 0: + self.t_cur += 1 + if self.t_cur >= self.t_i: + self.t_cur = self.t_cur - self.t_i + self.t_i = self.t_i * self.t_mult + + lr = self.eta_min + (self.base_lr - self.eta_min) * \ + (1 + math.cos(math.pi * self.t_cur / self.t_i)) / 2 + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +class OneCycleLR(_LRScheduler): + r""" + One cycle learning rate scheduler + + Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + + The 1cycle learning rate policy changes the learning rate after every batch. + This scheduler is not chainable. + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + + Args: + lr (float): Initial learning rate which is the lower boundary in the cycle. + steps_per_epoch (int): The number of steps per epoch to train for. + max_epoch (int): The number of epochs to train for. + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + div_factor (float): Determines the max learning rate via + :math:`max_lr = lr * div_factor` + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + :math:`min_lr = lr / final_div_factor` + Default: 1e4 + warmup_epochs (int, optional): The number of epochs to Warmup. 
Default: 0 + + Outputs: + numpy.ndarray, shape=(1, steps_per_epoch*max_epoch) + + Example: + >>> scheduler = OneCycleLR(lr=0.1, steps_per_epoch=5000, max_epoch=90) + >>> lr = scheduler.get_lr() + """ + def __init__(self, + lr, + steps_per_epoch, + max_epoch, + pct_start=0.3, + anneal_strategy='cos', + div_factor=25., + final_div_factor=1e4, + warmup_epochs=0): + + self.warmup = _LinearWarmUp(lr, warmup_epochs, steps_per_epoch) + super(OneCycleLR, self).__init__(lr, max_epoch, steps_per_epoch) + + self.step_size_up = float(pct_start * self.total_steps) - 1 + self.step_size_down = float(self.total_steps - self.step_size_up) - 1 + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) + + # Validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) + if anneal_strategy == 'cos': + self.anneal_func = self._annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = self._annealing_linear + + # Initialize learning rate variables + self.max_lr = lr * div_factor + self.min_lr = lr / final_div_factor + + def _annealing_cos(self, start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + def _annealing_linear(self, start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + return (end - start) * pct + start + + def get_lr(self): + warmup_steps = self.warmup.get_warmup_steps() + + lr_each_step = [] + for i in range(self.total_steps): + if i < warmup_steps: + lr = self.warmup.get_lr(i+1) + else: + if i <= self.step_size_up: + lr = self.anneal_func(self.base_lr, self.max_lr, i / self.step_size_up) + + else: + down_step_num = i - self.step_size_up + lr = self.anneal_func(self.max_lr, self.min_lr, down_step_num / self.step_size_down) + + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + """ + Linear warmup learning rate scheduler + """ + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """ + Warmup step learning rate scheduler + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma**milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, 
lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0): + """ + Warmup cosine learning rate scheduler + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_v2(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0): + """ + Warmup cosine v2 learning rate scheduler + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + last_lr = 0 + last_epoch_v1 = 0 + + t_max_v2 = int(max_epoch*1/3) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + if i < total_steps*2/3: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2 + last_lr = lr + last_epoch_v1 = last_epoch + else: + base_lr = last_lr + last_epoch = last_epoch-last_epoch_v1 + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / t_max_v2)) / 2 + + lr_each_step.append(lr) + return np.array(lr_each_step).astype(np.float32) + + +def warmup_cosine_annealing_lr_sample(lr, steps_per_epoch, warmup_epochs, max_epoch, t_max, eta_min=0): + """ + Warmup cosine learning rate scheduler sampler + """ + start_sample_epoch = 60 + step_sample = 2 + tobe_sampled_epoch = 60 + end_sampled_epoch = start_sample_epoch + step_sample*tobe_sampled_epoch + max_sampled_epoch = max_epoch+tobe_sampled_epoch + t_max = max_sampled_epoch + + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + total_sampled_steps = int(max_sampled_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + + for i in range(total_sampled_steps): + last_epoch = i // steps_per_epoch + if last_epoch in range(start_sample_epoch, end_sampled_epoch, step_sample): + continue + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi*last_epoch / t_max)) / 2 + lr_each_step.append(lr) + + assert total_steps == len(lr_each_step) + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/official/cv/centerface/src/mobile_v2.py b/model_zoo/official/cv/centerface/src/mobile_v2.py new file mode 100644 index 0000000000..6c0cd269f8 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/mobile_v2.py @@ -0,0 +1,203 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""modified mobilenet_v2 backbone""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops.operations import TensorAdd +from mindspore import Parameter +from mindspore.common.initializer import initializer + +from src.var_init import KaimingNormal + +__all__ = ['MobileNetV2', 'mobilenet_v2', 'DepthWiseConv'] + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + +class DepthWiseConv(nn.Cell): + """ + Depthwise convolution + """ + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): + super(DepthWiseConv, self).__init__() + self.has_bias = has_bias + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, kernel_size=kernel_size, + stride=stride, pad_mode=pad_mode, pad=pad) + self.bias_add = P.BiasAdd() + + weight_shape = [channel_multiplier, in_planes, kernel_size, kernel_size] + self.weight = Parameter(initializer(KaimingNormal(mode='fan_out'), weight_shape), name='weight') + + if has_bias: + bias_shape = [channel_multiplier * in_planes] + self.bias = Parameter(initializer('zeros', bias_shape), name='bias') + else: + self.bias = None + + def construct(self, x): + output = self.depthwise_conv(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + return output + + +class ConvBNReLU(nn.Cell): + """ + Convolution and batchnorm and relu + """ + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__() + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode="pad", padding=padding, + has_bias=False) + else: + conv = DepthWiseConv(in_planes, kernel_size, stride, pad_mode="pad", pad=padding) + + layers = [conv, nn.BatchNorm2d(out_planes).add_flags_recursive(fp32=True), nn.ReLU6()] #, momentum=0.9 + self.features = nn.SequentialCell(layers) + self.in_planes = in_planes + self.print = P.Print() + + def construct(self, x): + x = self.features(x) + return x + + +class InvertedResidual(nn.Cell): + """ + Inverted residual module + """ + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), + nn.BatchNorm2d(oup).add_flags_recursive(fp32=True) + ]) + + self.conv = nn.SequentialCell(layers) + self.add = TensorAdd() + self.cast = P.Cast() + + 
def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + + return x + + +class MobileNetV2(nn.Cell): + """ + MobileNet V2 main class, backbone + + Args: + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + """ + def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + self.feat_id = [1, 2, 4, 6] + self.feat_channel = [] + + # only check the first element, assuming user knows t,c,n,s are required + if inverted_residual_setting is None or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + + for index, (t, c, n, s) in enumerate(inverted_residual_setting): + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + if index == 1: + self.need_fp1 = nn.SequentialCell(features) + self.feat_channel.append(output_channel) + features = [] + elif index == 2: + self.need_fp2 = nn.SequentialCell(features) + self.feat_channel.append(output_channel) + features = [] + elif index == 4: + self.need_fp3 = nn.SequentialCell(features) + self.feat_channel.append(output_channel) + features = [] + elif index == 6: + self.need_fp4 = nn.SequentialCell(features) + self.feat_channel.append(output_channel) + features = [] + + + def construct(self, x): + x1 = self.need_fp1(x) + x2 = self.need_fp2(x1) + x3 = self.need_fp3(x2) + x4 = self.need_fp4(x3) + return x1, x2, x3, x4 + +def mobilenet_v2(**kwargs): + return MobileNetV2(**kwargs) diff --git a/model_zoo/official/cv/centerface/src/utils.py b/model_zoo/official/cv/centerface/src/utils.py new file mode 100644 index 0000000000..bbf48bca2f --- /dev/null +++ b/model_zoo/official/cv/centerface/src/utils.py @@ -0,0 +1,247 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
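+#
+# Minimal usage sketch (a hedged illustration only; `net`, `ckpt_path`, `args`,
+# `dataset`, `rank` and `group_size` are placeholder names, not objects defined
+# in this file):
+#   args.logger = get_logger('./outputs', rank)   # rank-aware console/file logger
+#   net = load_backbone(net, ckpt_path, args)     # copy mobilenet_v2 weights into the 'base' backbone cells
+#   sampler = DistributedSampler(dataset, rank, group_size, shuffle=True)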
+# ============================================================================ +"""auxiliary functions for train, to print and preload""" + +import math +import logging +import os +import sys +from datetime import datetime +import numpy as np + +from mindspore.train.serialization import load_checkpoint +import mindspore.nn as nn + +from src.mobile_v2 import DepthWiseConv + +def load_backbone(net, ckpt_path, args): + """ + Load backbone + """ + param_dict = load_checkpoint(ckpt_path) + centerface_backbone_prefix = 'base' + mobilev2_backbone_prefix = 'network.backbone' + find_param = [] + not_found_param = [] + + def replace_names(name, replace_name, replace_idx): + names = name.split('.') + if len(names) < 4: + raise "centerface_backbone_prefix name too short" + tmp = names[2] + '.' + names[3] + if replace_name != tmp: + replace_name = tmp + replace_idx += 1 + name = name.replace(replace_name, 'features' + '.' + str(replace_idx)) + return name, replace_name, replace_idx + + replace_name = 'need_fp1.0' + replace_idx = 0 + for name, cell in net.cells_and_names(): + if name.startswith(centerface_backbone_prefix): + name = name.replace(centerface_backbone_prefix, mobilev2_backbone_prefix) + if isinstance(cell, (nn.Conv2d, nn.Dense, DepthWiseConv)): + name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx) + mobilev2_weight = '{}.weight'.format(name) + mobilev2_bias = '{}.bias'.format(name) + if mobilev2_weight in param_dict: + cell.weight.set_data(param_dict[mobilev2_weight].data) + find_param.append(mobilev2_weight) + else: + not_found_param.append(mobilev2_weight) + if mobilev2_bias in param_dict: + cell.bias.set_data(param_dict[mobilev2_bias].data) + find_param.append(mobilev2_bias) + else: + not_found_param.append(mobilev2_bias) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + name, replace_name, replace_idx = replace_names(name, replace_name, replace_idx) + mobilev2_moving_mean = '{}.moving_mean'.format(name) + mobilev2_moving_variance = '{}.moving_variance'.format(name) + mobilev2_gamma = '{}.gamma'.format(name) + mobilev2_beta = '{}.beta'.format(name) + if mobilev2_moving_mean in param_dict: + cell.moving_mean.set_data(param_dict[mobilev2_moving_mean].data) + find_param.append(mobilev2_moving_mean) + else: + not_found_param.append(mobilev2_moving_mean) + if mobilev2_moving_variance in param_dict: + cell.moving_variance.set_data(param_dict[mobilev2_moving_variance].data) + find_param.append(mobilev2_moving_variance) + else: + not_found_param.append(mobilev2_moving_variance) + if mobilev2_gamma in param_dict: + cell.gamma.set_data(param_dict[mobilev2_gamma].data) + find_param.append(mobilev2_gamma) + else: + not_found_param.append(mobilev2_gamma) + if mobilev2_beta in param_dict: + cell.beta.set_data(param_dict[mobilev2_beta].data) + find_param.append(mobilev2_beta) + else: + not_found_param.append(mobilev2_beta) + + args.logger.info('================found_param {}========='.format(len(find_param))) + args.logger.info(find_param) + args.logger.info('================not_found_param {}========='.format(len(not_found_param))) + args.logger.info(not_found_param) + args.logger.info('=====load {} successfully ====='.format(ckpt_path)) + + return net + +def get_param_groups(network): + """ + Get param groups + """ + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + # print('no decay:{}'.format(parameter_name)) + 
no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + # print('no decay:{}'.format(parameter_name)) + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + # print('no decay:{}'.format(parameter_name)) + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + +class DistributedSampler(): + """ + Distributed sampler + """ + def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): + self.dataset = dataset + self.rank = rank + self.group_size = group_size + self.dataset_length = len(self.dataset) + self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size)) + self.total_size = self.num_samples * self.group_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + self.seed = (self.seed + 1) & 0xffffffff + np.random.seed(self.seed) + indices = np.random.permutation(self.dataset_length).tolist() + else: + indices = list(range(len(self.dataset.classes))) + + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + indices = indices[self.rank::self.group_size] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + +class AverageMeter(): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f', tb_writer=None): + self.name = name + self.fmt = fmt + self.reset() + self.tb_writer = tb_writer + self.cur_step = 1 + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + if self.tb_writer is not None: + self.tb_writer.add_scalar(self.name, self.val, self.cur_step) + self.cur_step += 1 + + def __str__(self): + fmtstr = '{name}:{avg' + self.fmt + '}' + return fmtstr.format(**self.__dict__) + +class LOGGER(logging.Logger): + """ + Logger class + """ + def __init__(self, logger_name, rank=0): + super(LOGGER, self).__init__(logger_name) + if rank % 8 == 0: + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + console.setFormatter(formatter) + self.addHandler(console) + + def setup_logging_file(self, log_dir, rank=0): + """ + Setup logging file + """ + self.rank = rank + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank) + self.log_fn = os.path.join(log_dir, log_name) + fh = logging.FileHandler(self.log_fn) + fh.setLevel(logging.INFO) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + fh.setFormatter(formatter) + self.addHandler(fh) + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO): + self._log(logging.INFO, msg, args, **kwargs) + + def save_args(self, args): + self.info('Args:') + args_dict = vars(args) + for key in args_dict.keys(): + # self.info('--> {}: {}'.format(key, args_dict[key])) + self.info('--> %s', key) + self.info('') + + def important_info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and self.rank == 0: + line_width = 2 + important_msg = '\n' + important_msg += ('*'*70 + 
'\n')*line_width + important_msg += ('*'*line_width + '\n')*2 + important_msg += '*'*line_width + ' '*8 + msg + '\n' + important_msg += ('*'*line_width + '\n')*2 + important_msg += ('*'*70 + '\n')*line_width + self.info(important_msg, *args, **kwargs) + +def get_logger(path, rank): + logger = LOGGER("centerface", rank) + logger.setup_logging_file(path, rank) + return logger diff --git a/model_zoo/official/cv/centerface/src/var_init.py b/model_zoo/official/cv/centerface/src/var_init.py new file mode 100644 index 0000000000..23318af4b0 --- /dev/null +++ b/model_zoo/official/cv/centerface/src/var_init.py @@ -0,0 +1,285 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""weight initilization""" + +import math +import numpy as np + +from mindspore.common import initializer as init +from mindspore.common.initializer import Initializer as MeInitializer +import mindspore.nn as nn +from mindspore import Tensor + +def assignment(arr, num): + """Assign the value of `num` to `arr`.""" + if arr.shape == (): + arr = arr.reshape((1)) + arr[:] = num + arr = arr.reshape(()) + else: + if isinstance(num, np.ndarray): + arr[:] = num[:] + else: + arr[:] = num + return arr + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. 
+ The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + if nonlinearity == 'tanh': + return 5.0 / 3 + if nonlinearity == 'relu': + return math.sqrt(2.0) + if nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def _calculate_correct_fan(array, mode): + mode = mode.lower() + valid_modes = ['fan_in', 'fan_out'] + if mode not in valid_modes: + raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes)) + + fan_in, fan_out = _calculate_fan_in_and_fan_out(array) + return fan_in if mode == 'fan_in' else fan_out + + +def kaiming_uniform_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') + """ + fan = _calculate_correct_fan(arr, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + return np.random.uniform(-bound, bound, arr.shape) + + +def kaiming_normal_(arr, a=0, mode='fan_in', nonlinearity='leaky_relu'): + r"""Fills the input `Tensor` with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + normal distribution. 
The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') + """ + fan = _calculate_correct_fan(arr, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + return np.random.normal(0, std, arr.shape) + + +def _calculate_fan_in_and_fan_out(arr): + """ + Calculate fan in and fan out + """ + dimensions = len(arr.shape) + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for array with fewer than 2 dimensions") + + num_input_fmaps = arr.shape[1] + num_output_fmaps = arr.shape[0] + receptive_field_size = 1 + if dimensions > 2: + receptive_field_size = arr[0][0].size + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def xavier_uniform_(arr, gain=1.): + # type: (Tensor, float) -> Tensor + r"""Fills the input `Tensor` with values according to the method + described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform + distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. 
+ + Args: + tensor: an n-dimensional `Tensor` + gain: an optional scaling factor + + Examples: + >>> w = np.empty(3, 5) + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(arr) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return np.random.uniform(-a, a, arr.shape) + + +class XavierUniform(MeInitializer): + def __init__(self, gain=1.): + super(XavierUniform, self).__init__() + self.gain = gain + + def _initialize(self, arr): + tmp = xavier_uniform_(arr, self.gain) + assignment(arr, tmp) + + +class KaimingUniform(MeInitializer): + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingUniform, self).__init__() + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + + def _initialize(self, arr): + tmp = kaiming_uniform_(arr, self.a, self.mode, self.nonlinearity) + assignment(arr, tmp) + + +class KaimingNormal(MeInitializer): + def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'): + super(KaimingNormal, self).__init__() + self.a = a + self.mode = mode + self.nonlinearity = nonlinearity + + def _initialize(self, arr): + tmp = kaiming_normal_(arr, self.a, self.mode, self.nonlinearity) + assignment(arr, tmp) + +class RandomNormal(MeInitializer): + def __init__(self, std=0.001): + super(RandomNormal, self).__init__() + self.std = std + + def _initialize(self, arr): + std = self.std + tmp = np.random.normal(0, std, arr.shape) + assignment(arr, tmp) + +def default_recurisive_init(custom_cell): + """ + Parameters init + """ + np.random.seed(123) + for name, cell in custom_cell.cells_and_names(): + if 'hm' in name or 'wh' in name or 'off' in name or 'kps' in name: + if isinstance(cell, (nn.Conv2d)): + cell.weight.set_data(init.initializer(RandomNormal(), cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, + cell.bias.data.dtype).to_tensor()) + continue + + if isinstance(cell, (nn.Conv2d)): + cell.weight.set_data(init.initializer(KaimingNormal(mode='fan_out'), + cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, + cell.bias.data.dtype).to_tensor()) + elif isinstance(cell, nn.Dense): + cell.weight.set_data(init.initializer(KaimingNormal(mode='fan_out'), + cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, + cell.bias.data.dtype).to_tensor()) + elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)): + cell.gamma.set_data(init.initializer('ones', cell.gamma.data.shape).to_tensor()) + cell.beta.set_data(init.initializer('zeros', cell.beta.data.shape).to_tensor()) + elif isinstance(cell, nn.Conv2dTranspose): + cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5), mode='fan_out'), + cell.weight.data.shape, + cell.weight.data.dtype).to_tensor()) + if cell.bias is not None: + cell.bias.set_data(init.initializer('zeros', cell.bias.data.shape, + cell.bias.data.dtype).to_tensor()) diff --git a/model_zoo/official/cv/centerface/test.py b/model_zoo/official/cv/centerface/test.py new file mode 100644 index 0000000000..029e76a37d --- /dev/null +++ b/model_zoo/official/cv/centerface/test.py @@ -0,0 +1,160 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Test centerface example +""" +import os +import time +import argparse +import datetime +import scipy.io as sio + +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.utils import get_logger +from src.var_init import default_recurisive_init +from src.centerface import CenterfaceMobilev2, CenterFaceWithNms +from src.config import ConfigCenterface + +from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector +from dependency.evaluate.eval import evaluation + +dev_id = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False, + device_target="Ascend", save_graphs=False, device_id=dev_id) + +parser = argparse.ArgumentParser('mindspore coco training') +parser.add_argument('--data_dir', type=str, default='', help='train data dir') +parser.add_argument('--test_model', type=str, default='', help='test model dir') +parser.add_argument('--ground_truth_mat', type=str, default='', help='ground_truth, mat type') +parser.add_argument('--save_dir', type=str, default='', help='save_path for evaluate') +parser.add_argument('--ground_truth_path', type=str, default='', help='ground_truth path, contain all mat file') +parser.add_argument('--eval', type=int, default=0, help='if do eval after test') +parser.add_argument('--eval_script_path', type=str, default='', help='evaluate script path') +parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') +parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location') +parser.add_argument('--ckpt_name', type=str, default="", help='input model name') +parser.add_argument('--device_num', type=int, default=1, help='device num for testing') +parser.add_argument('--steps_per_epoch', type=int, default=198, help='steps for each epoch') +parser.add_argument('--start', type=int, default=0, help='start loop number, used to calculate first epoch number') +parser.add_argument('--end', type=int, default=18, help='end loop number, used to calculate last epoch number') + +args, _ = parser.parse_known_args() + +if __name__ == "__main__": + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + + if args.ckpt_name != "": + args.start = 0 + args.end = 1 + + for loop in range(args.start, args.end, 1): + network = CenterfaceMobilev2() + default_recurisive_init(network) + + if args.ckpt_name == "": + ckpt_num = loop * args.device_num + args.rank + 1 + ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt" + else: + ckpt_name = args.ckpt_name + + test_model = args.test_model + ckpt_name + if not test_model: + args.logger.info('load_model {} none'.format(test_model)) + continue + + 
if os.path.isfile(test_model): + param_dict = load_checkpoint(test_model) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values + + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(test_model)) + else: + args.logger.info('{} not exists or not a pre-trained file'.format(test_model)) + continue + + train_network_type_nms = 1 # default with num + if train_network_type_nms: + network = CenterFaceWithNms(network) + args.logger.info('train network type with nms') + network.set_train(False) + args.logger.info('finish get network') + + config = ConfigCenterface() + + # test network ----------- + start = time.time() + + ground_truth_mat = sio.loadmat(args.ground_truth_mat) + event_list = ground_truth_mat['event_list'] + file_list = ground_truth_mat['file_list'] + if args.ckpt_name == "": + save_path = args.save_dir + str(ckpt_num) + '/' + else: + save_path = args.save_dir+ '/' + detector = CenterFaceDetector(config, network) + + for index, event in enumerate(event_list): + file_list_item = file_list[index][0] + im_dir = event[0][0] + if not os.path.exists(save_path + im_dir): + os.makedirs(save_path + im_dir) + args.logger.info('save_path + im_dir={}'.format(save_path + im_dir)) + for num, file in enumerate(file_list_item): + im_name = file[0][0] + zip_name = '%s/%s.jpg' % (im_dir, im_name) + img_path = os.path.join(args.data_dir, zip_name) + args.logger.info('img_path={}'.format(img_path)) + + dets = detector.run(img_path)['results'] + + f = open(save_path + im_dir + '/' + im_name + '.txt', 'w') + f.write('{:s}\n'.format('%s/%s.jpg' % (im_dir, im_name))) + f.write('{:d}\n'.format(len(dets))) + for b in dets[1]: + x1, y1, x2, y2, s = b[0], b[1], b[2], b[3], b[4] + f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), s)) + f.close() + args.logger.info('event:{}, num:{}'.format(index + 1, num + 1)) + + end = time.time() + args.logger.info("============num {} time {}".format(num, (end-start)*1000)) + start = end + + if args.eval: + args.logger.info('==========start eval===============') + args.logger.info("test output path = {}".format(save_path)) + if os.path.isdir(save_path): + evaluation(save_path, args.ground_truth_path) + else: + args.logger.info('no test output path') + args.logger.info('==========end eval===============') + + if args.ckpt_name != "": + break + + args.logger.info('==========end testing===============') diff --git a/model_zoo/official/cv/centerface/train.py b/model_zoo/official/cv/centerface/train.py new file mode 100644 index 0000000000..d8dfde1b2c --- /dev/null +++ b/model_zoo/official/cv/centerface/train.py @@ -0,0 +1,348 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +""" +Train centerface and get network model files(.ckpt) +""" + +import os +import time +import argparse +import datetime +import numpy as np + +from mindspore import context +from mindspore.context import ParallelMode +from mindspore.nn.optim.adam import Adam +from mindspore.nn.optim.momentum import Momentum +from mindspore.nn.optim.sgd import SGD +from mindspore import Tensor +from mindspore.communication.management import init, get_rank, get_group_size +from mindspore.train.callback import ModelCheckpoint, RunContext +from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.profiler.profiling import Profiler + +from src.utils import get_logger +from src.utils import AverageMeter +from src.lr_scheduler import warmup_step_lr +from src.lr_scheduler import warmup_cosine_annealing_lr, \ + warmup_cosine_annealing_lr_v2, warmup_cosine_annealing_lr_sample +from src.lr_scheduler import MultiStepLR +from src.var_init import default_recurisive_init +from src.centerface import CenterfaceMobilev2 +from src.utils import load_backbone, get_param_groups +from src.config import ConfigCenterface +from src.centerface import CenterFaceWithLossCell, TrainingWrapper +from src.dataset import GetDataLoader + +dev_id = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False, + device_target="Ascend", save_graphs=False, device_id=dev_id, reserve_class_name_in_scope=False) + +parser = argparse.ArgumentParser('mindspore coco training') + +# dataset related +parser.add_argument('--data_dir', type=str, default='', help='train data dir') +parser.add_argument('--annot_path', type=str, default='', help='train data annotation path') +parser.add_argument('--img_dir', type=str, default='', help='train data img dir') +parser.add_argument('--per_batch_size', default=8, type=int, help='batch size for per gpu') + +# network related +parser.add_argument('--pretrained_backbone', default='', type=str, help='model_path, local pretrained backbone' + ' model to load') +parser.add_argument('--resume', default='', type=str, help='path of pretrained centerface_model') + +# optimizer and lr related +parser.add_argument('--lr_scheduler', default='multistep', type=str, + help='lr-scheduler, option type: exponential, cosine_annealing') +parser.add_argument('--lr', default=4e-3, type=float, help='learning rate of the training') +parser.add_argument('--lr_epochs', type=str, default='90,120', help='epoch of lr changing') +parser.add_argument('--lr_gamma', type=float, default=0.1, + help='decrease lr by a factor of exponential lr_scheduler') +parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler') +parser.add_argument('--t_max', type=int, default=140, help='T-max in cosine_annealing scheduler') +parser.add_argument('--max_epoch', type=int, default=140, help='max epoch num to train the model') +parser.add_argument('--warmup_epochs', default=0, type=float, help='warmup epoch') +parser.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay') +parser.add_argument('--momentum', type=float, default=0.9, help='momentum') +parser.add_argument('--optimizer', default='adam', type=str, + help='optimizer type, default: adam') + +# loss related +parser.add_argument('--loss_scale', type=int, default=1024, help='static loss scale') 
+parser.add_argument('--label_smooth', type=int, default=0, help='whether to use label smooth in CE')
+parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='smooth strength of original one-hot')
+
+# logging related
+parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
+parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
+parser.add_argument('--ckpt_interval', type=int, default=None, help='ckpt_interval')
+
+parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
+
+# distributed related
+parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
+parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
+parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
+
+# roma obs
+parser.add_argument('--train_url', type=str, default="", help='train url')
+
+# profiler init, can be opened when you debug. Do not open it during training, since it costs memory and disk space
+parser.add_argument('--need_profiler', type=int, default=0, help='whether use profiler')
+
+# reset default config
+parser.add_argument('--training_shape', type=str, default="", help='fix training shape')
+parser.add_argument('--resize_rate', type=int, default=None, help='resize rate for multi-scale training')
+
+args, _ = parser.parse_known_args()
+
+if args.lr_scheduler == 'cosine_annealing' and args.max_epoch > args.t_max:
+    args.t_max = args.max_epoch
+
+args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
+
+
+def convert_training_shape(args_):
+    """
+    Convert training shape
+    """
+    training_shape = [int(args_.training_shape), int(args_.training_shape)]
+    return training_shape
+
+
+if __name__ == "__main__":
+    # init distributed
+    if args.is_distributed:
+        init()
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+
+    # select whether only the master rank or all ranks save the ckpt, compatible with model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+    args.logger.save_args(args)
+
+    if args.need_profiler:
+        profiler = Profiler(output_path=args.outputs_dir)
+
+    loss_meter = AverageMeter('loss')
+
+    context.reset_auto_parallel_context()
+    if args.is_distributed:
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        degree = get_group_size()
+    else:
+        parallel_mode = ParallelMode.STAND_ALONE
+        degree = 1
+
+    # context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=degree, parameter_broadcast=True, gradients_mean=True)
+    # Notice: parameter_broadcast should be supported, but the current version has bugs, so it is disabled here.
+ # To make sure the init weight on all npu is the same, we need to set a static seed in default_recurisive_init when weight initialization + context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree) + network = CenterfaceMobilev2() + # init, to avoid overflow, some std of weight should be enough small + default_recurisive_init(network) + + if args.pretrained_backbone: + network = load_backbone(network, args.pretrained_backbone, args) + args.logger.info('load pre-trained backbone {} into network'.format(args.pretrained_backbone)) + else: + args.logger.info('Not load pre-trained backbone, please be careful') + + if os.path.isfile(args.resume): + param_dict = load_checkpoint(args.resume) + param_dict_new = {} + for key, values in param_dict.items(): + if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): + continue + elif key.startswith('centerface_network.'): + param_dict_new[key[19:]] = values + else: + param_dict_new[key] = values + + load_param_into_net(network, param_dict_new) + args.logger.info('load_model {} success'.format(args.resume)) + else: + args.logger.info('{} not set/exists or not a pre-trained file'.format(args.resume)) + + network = CenterFaceWithLossCell(network) + args.logger.info('finish get network') + + config = ConfigCenterface() + config.data_dir = args.data_dir + config.annot_path = args.annot_path + config.img_dir = args.img_dir + + config.label_smooth = args.label_smooth + config.label_smooth_factor = args.label_smooth_factor + # -------------reset config----------------- + if args.training_shape: + config.multi_scale = [convert_training_shape(args)] + + if args.resize_rate: + config.resize_rate = args.resize_rate + + # data loader + data_loader, args.steps_per_epoch = GetDataLoader(per_batch_size=args.per_batch_size, + max_epoch=args.max_epoch, + rank=args.rank, + group_size=args.group_size, + config=config, + split='train') + args.steps_per_epoch = args.steps_per_epoch // args.max_epoch + args.logger.info('Finish loading dataset') + + if not args.ckpt_interval: + args.ckpt_interval = args.steps_per_epoch + + # lr scheduler + if args.lr_scheduler == 'multistep': + lr_fun = MultiStepLR(args.lr, args.lr_epochs, args.lr_gamma, args.steps_per_epoch, args.max_epoch, + args.warmup_epochs) + lr = lr_fun.get_lr() + elif args.lr_scheduler == 'exponential': + lr = warmup_step_lr(args.lr, + args.lr_epochs, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + gamma=args.lr_gamma + ) + elif args.lr_scheduler == 'cosine_annealing': + lr = warmup_cosine_annealing_lr(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.t_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_V2': + lr = warmup_cosine_annealing_lr_v2(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.t_max, + args.eta_min) + elif args.lr_scheduler == 'cosine_annealing_sample': + lr = warmup_cosine_annealing_lr_sample(args.lr, + args.steps_per_epoch, + args.warmup_epochs, + args.max_epoch, + args.t_max, + args.eta_min) + else: + raise NotImplementedError(args.lr_scheduler) + + if args.optimizer == "adam": + opt = Adam(params=get_param_groups(network), + learning_rate=Tensor(lr), + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + args.logger.info("use adam optimizer") + elif args.optimizer == "sgd": + opt = SGD(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + 
loss_scale=args.loss_scale) + else: + opt = Momentum(params=get_param_groups(network), + learning_rate=Tensor(lr), + momentum=args.momentum, + weight_decay=args.weight_decay, + loss_scale=args.loss_scale) + + network = TrainingWrapper(network, opt, sens=args.loss_scale) + network.set_train() + + if args.rank_save_ckpt_flag: + # checkpoint save + ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval + ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval, + keep_checkpoint_max=ckpt_max_num) + ckpt_cb = ModelCheckpoint(config=ckpt_config, + directory=args.outputs_dir, + prefix='{}'.format(args.rank)) + cb_params = _InternalCallbackParam() + cb_params.train_network = network + cb_params.epoch_num = ckpt_max_num + cb_params.cur_epoch_num = 1 + run_context = RunContext(cb_params) + ckpt_cb.begin(run_context) + + args.logger.info('args.steps_per_epoch = {} args.ckpt_interval ={}'.format(args.steps_per_epoch, + args.ckpt_interval)) + + t_end = time.time() + + for i_all, batch_load in enumerate(data_loader): + i = i_all % args.steps_per_epoch + epoch = i_all // args.steps_per_epoch + 1 + images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks = batch_load + + images = Tensor(images) + hm = Tensor(hm) + reg_mask = Tensor(reg_mask) + ind = Tensor(ind) + wh = Tensor(wh) + wight_mask = Tensor(wight_mask) + hm_offset = Tensor(hm_offset) + hps_mask = Tensor(hps_mask) + landmarks = Tensor(landmarks) + + loss, overflow, scaling = network(images, hm, reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks) + # Tensor to numpy + overflow = np.all(overflow.asnumpy()) + loss = loss.asnumpy() + loss_meter.update(loss) + args.logger.info('epoch:{}, iter:{}, avg_loss:{}, loss:{}, overflow:{}, loss_scale:{}'.format(epoch, + i, + loss_meter, + loss, + overflow, + scaling.asnumpy() + )) + + if args.rank_save_ckpt_flag: + # ckpt progress + cb_params.cur_epoch_num = epoch + cb_params.cur_step_num = i + 1 + (epoch-1)*args.steps_per_epoch + cb_params.batch_num = i + 2 + (epoch-1)*args.steps_per_epoch + ckpt_cb.step_end(run_context) + + if (i_all+1) % args.steps_per_epoch == 0: + time_used = time.time() - t_end + fps = args.per_batch_size * args.steps_per_epoch * args.group_size / time_used + if args.rank == 0: + args.logger.info( + 'epoch[{}], {}, {:.2f} imgs/sec, lr:{}' + .format(epoch, loss_meter, fps, lr[i + (epoch-1)*args.steps_per_epoch]) + ) + t_end = time.time() + loss_meter.reset() + + if args.need_profiler: + profiler.analyse() + + args.logger.info('==========end training===============')
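
For reference, the initializers added in `src/var_init.py` above follow the standard He/Xavier formulas. The short NumPy sketch below is an illustrative check only, not part of the patch: it recomputes the fan_out-mode He-normal standard deviation that `kaiming_normal_` samples from, using an assumed 64x32x3x3 convolution weight shape rather than a layer taken from the CenterFace network.

```python
# Illustrative check of the He-normal std used by kaiming_normal_ in var_init.py.
# The conv weight shape (out_channels=64, in_channels=32, kh=3, kw=3) is an
# assumed example, not an actual CenterFace layer.
import math
import numpy as np

weight = np.zeros((64, 32, 3, 3), dtype=np.float32)

receptive_field_size = weight[0][0].size          # 3 * 3 = 9
fan_out = weight.shape[0] * receptive_field_size  # 64 * 9 = 576
gain = math.sqrt(2.0)                             # calculate_gain('leaky_relu', 0) == sqrt(2)
std = gain / math.sqrt(fan_out)                   # ~0.0589

sample = np.random.normal(0, std, weight.shape)   # what kaiming_normal_ returns
print(round(std, 4), sample.shape)
```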