diff --git a/model_zoo/official/cv/ctpn/README.md b/model_zoo/official/cv/ctpn/README.md new file mode 100644 index 0000000000..420f7a8f24 --- /dev/null +++ b/model_zoo/official/cv/ctpn/README.md @@ -0,0 +1,293 @@ +![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png) + + + +# CTPN for Ascend + +- [CTPN Description](#CTPN-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Features](#features) + - [Mixed Precision](#mixed-precision) +- [Environment Requirements](#environment-requirements) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Training Process](#training-process) + - [Evaluation Process](#evaluation-process) + - [Evaluation](#evaluation) +- [Model Description](#model-description) + - [Performance](#performance) + - [Training Performance](#evaluation-performance) + - [Inference Performance](#evaluation-performance) +- [Description of Random Situation](#description-of-random-situation) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [CTPN Description](#contents) + +CTPN is a text detection model based on object detection method. It improves Faster R-CNN and combines with bidirectional LSTM, so ctpn is very effective for horizontal text detection. Another highlight of ctpn is to transform the text detection task into a series of small-scale text box detection.This idea was proposed in the paper "Detecting Text in Natural Image with Connectionist Text Proposal Network". + +[Paper](https://arxiv.org/pdf/1609.03605.pdf) Zhi Tian, Weilin Huang, Tong He, Pan He, Yu Qiao, "Detecting Text in Natural Image with Connectionist Text Proposal Network", ArXiv, vol. abs/1609.03605, 2016. + +# [Model architecture](#contents) + +The overall network architecture contains a VGG16 as backbone, and use bidirection lstm to extract context feature of the small-scale text box, then it used the RPN(RegionProposal Network) to predict the boundding box and probability. + +[Link](https://arxiv.org/pdf/1605.07314v1.pdf) + +# [Dataset](#contents) + +Here we used 6 datasets for training, and 1 datasets for Evaluation. + +- Dataset1: ICDAR 2013: Focused Scene Text + - Train: 142MB, 229 images + - Test: 110MB, 233 images +- Dataset2: ICDAR 2011: Born-Digital Images + - Train: 27.7MB, 410 images +- Dataset3: ICDAR 2015: + - Train:89MB, 1000 images +- Dataset4: SCUT-FORU: Flickr OCR Universal Database + - Train: 388MB, 1715 images +- Dataset5: CocoText v2(Subset of MSCOCO2017): + - Train: 13GB, 63686 images +- Dataset6: SVT(The Street View Dataset) + - Train: 115MB, 349 images + +# [Features](#contents) + +# [Environment Requirements](#contents) + +- Hardware(Ascend) + - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [Script description](#contents) + +## [Script and sample code](#contents) + +```shell +. +└─ctpn + ├── README.md # network readme + ├── eval.py # eval net + ├── scripts + │   ├── eval_res.sh # calculate precision and recall + │   ├── run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p) + │   ├── run_eval_ascend.sh # launch evaluating with ascend platform + │   └── run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p) + ├── src + │   ├── CTPN + │   │   ├── BoundingBoxDecode.py # bounding box decode + │   │   ├── BoundingBoxEncode.py # bounding box encode + │   │   ├── __init__.py # package init file + │   │   ├── anchor_generator.py # anchor generator + │   │   ├── bbox_assign_sample.py # proposal layer + │   │   ├── proposal_generator.py # proposla generator + │   │   ├── rpn.py # region-proposal network + │   │   └── vgg16.py # backbone + │   ├── config.py # training configuration + │   ├── convert_icdar2015.py # convert icdar2015 dataset label + │   ├── convert_svt.py # convert svt label + │   ├── create_dataset.py # create mindrecord dataset + │   ├── ctpn.py # ctpn network definition + │   ├── dataset.py # data proprocessing + │   ├── lr_schedule.py # learning rate scheduler + │   ├── network_define.py # network definition + │   └── text_connector + │   ├── __init__.py # package init file + │   ├── connect_text_lines.py # connect text lines + │   ├── detector.py # detect box + │   ├── get_successions.py # get succession proposal + │   └── utils.py # some functions which is commonly used + └── train.py # train net + +``` + +## [Training process](#contents) + +### Dataset + +To create dataset, download the dataset first and deal with it.We provided src/convert_svt.py and src/convert_icdar2015.py to deal with svt and icdar2015 dataset label.For svt dataset, you can deal with it as below: + +```shell + python convert_svt.py --dataset_path=/path/img --xml_file=/path/train.xml --location_dir=/path/location +``` + +For ICDAR2015 dataset, you can deal with it + +```shell + python convert_icdar2015.py --src_label_path=/path/train_label --target_label_path=/path/label +``` + +Then modify the src/config.py and add the dataset path.For each path, add IMAGE_PATH and LABEL_PATH into a list in config.An example is show as blow: + +```python + # create dataset + "coco_root": "/path/coco", + "coco_train_data_type": "train2017", + "cocotext_json": "/path/cocotext.v2.json", + "icdar11_train_path": ["/path/image/", "/path/label"], + "icdar13_train_path": ["/path/image/", "/path/label"], + "icdar15_train_path": ["/path/image/", "/path/label"], + "icdar13_test_path": ["/path/image/", "/path/label"], + "flick_train_path": ["/path/image/", "/path/label"], + "svt_train_path": ["/path/image/", "/path/label"], + "pretrain_dataset_path": "", + "finetune_dataset_path": "", + "test_dataset_path": "", +``` + +Then you can create dataset with src/create_dataset.py with the command as below: + +```shell +python src/create_dataset.py +``` + +### Usage + +- Ascend: + +```bash +# distribute training example(8p) +sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH] +# standalone training +sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] +# evaluation: +sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH] +``` + +The `pretrained_path` should be a checkpoint of vgg16 trained on Imagenet2012. The name of weight in dict should be totally the same, also the batch_norm should be enabled in the trainig of vgg16, otherwise fails in further steps.COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text).To get the vgg16 backbone, you can use the network structure defined in src/CTPN/vgg16.py.To train the backbone, copy the src/CTPN/vgg16.py under modelzoo/official/cv/vgg16/src/, and modify the vgg16/train.py to suit the new construction.You can fix it as below: + +```python +... +from src.vgg16 import VGG16 +... +network = VGG16() +... + +``` + +Then you can train it with ImageNet2012. +> Notes: +> RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html) , and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools). For large models like InceptionV4, it's better to export an external environment variable `export HCCL_CONNECT_TIMEOUT=600` to extend hccl connection checking time from the default 120 seconds to 600 seconds. Otherwise, the connection could be timeout since compiling time increases with the growth of model size. +> +> This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh` +> +> TASK_TYPE contains Pretraining and Finetune. For Pretraining, we use ICDAR2013, ICDAR2015, SVT, SCUT-FORU, CocoText v2. For Finetune, we use ICDAR2011, +ICDAR2013, SCUT-FORU to improve precision and recall, and when doing Finetune, we use the checkpoint training in Pretrain as our PRETRAINED_PATH. +> COCO_TEXT_PARSER_PATH coco_text.py can refer to [Link](https://github.com/andreasveit/coco-text). +> + +### Launch + +```bash +# training example + shell: + Ascend: + # distribute training example(8p) + sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH] + # standalone training + sh run_standalone_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH] +``` + +### Result + +Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training log will be redirected to `./log`, also the loss will be redirected to `./loss_0.log` like followings. + +```python +377 epoch: 1 step: 229 ,rpn_loss: 0.00355, rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00103, +399 epoch: 2 step: 229 ,rpn_loss: 0.00327,rpn_cls_loss: 0.00047, rpn_reg_loss: 0.00093, +424 epoch: 3 step: 229 ,rpn_loss: 0.00910, rpn_cls_loss: 0.00385, rpn_reg_loss: 0.00175, +``` + +## [Eval process](#contents) + +### Usage + +You can start training using python or shell scripts. The usage of shell scripts as follows: + +- Ascend: + +```bash + sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH] +``` + +After eval, you can get serval archive file named submit_ctpn-xx_xxxx.zip, which contains the name of your checkpoint file.To evalulate it, you can use the scripts provided by the ICDAR2013 network, you can download the Deteval scripts from the [link](https://rrc.cvc.uab.es/?com=downloads&action=download&ch=2&f=aHR0cHM6Ly9ycmMuY3ZjLnVhYi5lcy9zdGFuZGFsb25lcy9zY3JpcHRfdGVzdF9jaDJfdDFfZTItMTU3Nzk4MzA2Ny56aXA=) +After download the scripts, unzip it and put it under ctpn/scripts and use eval_res.sh to get the result.You will get files as below: + +```text +gt.zip +readme.txt +rrc_evalulation_funcs_1_1.py +script.py +``` + +Then you can run the scripts/eval_res.sh to calculate the evalulation result. + +```base +bash eval_res.sh +``` + +### Result + +Evaluation result will be stored in the example path, you can find result like the followings in `log`. + +```text +{"precision": 0.90791, "recall": 0.86118, "hmean": 0.88393} +``` + +# [Model description](#contents) + +## [Performance](#contents) + +### Training Performance + +| Parameters | Ascend | +| -------------------------- | ------------------------------------------------------------ | +| Model Version | CTPN | +| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | +| uploaded Date | 02/06/2021 | +| MindSpore Version | 1.1.1 | +| Dataset | 16930 images | +| Batch_size | 2 | +| Training Parameters | src/config.py | +| Optimizer | Momentum | +| Loss Function | SoftmaxCrossEntropyWithLogits for classification, SmoothL2Loss for bbox regression| +| Loss | ~0.04 | +| Total time (8p) | 6h | +| Scripts | [ctpn script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/ctpn) | + +#### Inference Performance + +| Parameters | Ascend | +| ------------------- | --------------------------- | +| Model Version | CTPN | +| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | +| Uploaded Date | 02/06/2020 | +| MindSpore Version | 1.1.1 | +| Dataset | 229 images | +| Batch_size | 1 | +| Accuracy | precision=0.9079, recall=0.8611 F-measure:0.8839 | +| Total time | 1 min | +| Model for inference | 135M (.ckpt file) | + +#### Training performance results + +| **Ascend** | train performance | +| :--------: | :---------------: | +| 1p | 10 img/s | + +| **Ascend** | train performance | +| :--------: | :---------------: | +| 8p | 84 img/s | + +# [Description of Random Situation](#contents) + +We set seed to 1 in train.py. + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/ctpn/eval.py b/model_zoo/official/cv/ctpn/eval.py new file mode 100644 index 0000000000..ddb8509252 --- /dev/null +++ b/model_zoo/official/cv/ctpn/eval.py @@ -0,0 +1,118 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Evaluation for CTPN""" +import os +import argparse +import time +import numpy as np +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed +from src.ctpn import CTPN +from src.config import config +from src.dataset import create_ctpn_dataset +from src.text_connector.detector import detect +set_seed(1) + +parser = argparse.ArgumentParser(description="CTPN evaluation") +parser.add_argument("--dataset_path", type=str, default="", help="Dataset path.") +parser.add_argument("--image_path", type=str, default="", help="Image path.") +parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") +args_opt = parser.parse_args() +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) + +def ctpn_infer_test(dataset_path='', ckpt_path='', img_dir=''): + """ctpn infer.""" + print("ckpt path is {}".format(ckpt_path)) + ds = create_ctpn_dataset(dataset_path, batch_size=config.test_batch_size, repeat_num=1, is_training=False) + config.batch_size = config.test_batch_size + total = ds.get_dataset_size() + print("*************total dataset size is {}".format(total)) + net = CTPN(config, is_training=False) + param_dict = load_checkpoint(ckpt_path) + load_param_into_net(net, param_dict) + net.set_train(False) + eval_iter = 0 + + print("\n========================================\n") + print("Processing, please wait a moment.") + img_basenames = [] + output_dir = os.path.join(os.getcwd(), "submit") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + for file in os.listdir(img_dir): + img_basenames.append(os.path.basename(file)) + for data in ds.create_dict_iterator(): + img_data = data['image'] + img_metas = data['image_shape'] + gt_bboxes = data['box'] + gt_labels = data['label'] + gt_num = data['valid_num'] + + start = time.time() + # run net + output = net(img_data, img_metas, gt_bboxes, gt_labels, gt_num) + gt_bboxes = gt_bboxes.asnumpy() + gt_labels = gt_labels.asnumpy() + gt_num = gt_num.asnumpy().astype(bool) + end = time.time() + proposal = output[0] + proposal_mask = output[1] + print("start to draw pic") + for j in range(config.test_batch_size): + img = img_basenames[config.test_batch_size * eval_iter + j] + all_box_tmp = proposal[j].asnumpy() + all_mask_tmp = np.expand_dims(proposal_mask[j].asnumpy(), axis=1) + using_boxes_mask = all_box_tmp * all_mask_tmp + textsegs = using_boxes_mask[:, 0:4].astype(np.float32) + scores = using_boxes_mask[:, 4].astype(np.float32) + shape = img_metas.asnumpy()[0][:2].astype(np.int32) + bboxes = detect(textsegs, scores[:, np.newaxis], shape) + from PIL import Image, ImageDraw + im = Image.open(img_dir + '/' + img) + draw = ImageDraw.Draw(im) + image_h = img_metas.asnumpy()[j][2] + image_w = img_metas.asnumpy()[j][3] + gt_boxs = gt_bboxes[j][gt_num[j], :] + for gt_box in gt_boxs: + gt_x1 = gt_box[0] / image_w + gt_y1 = gt_box[1] / image_h + gt_x2 = gt_box[2] / image_w + gt_y2 = gt_box[3] / image_h + draw.line([(gt_x1, gt_y1), (gt_x1, gt_y2), (gt_x2, gt_y2), (gt_x2, gt_y1), (gt_x1, gt_y1)],\ + fill='green', width=2) + file_name = "res_" + img.replace("jpg", "txt") + output_file = os.path.join(output_dir, file_name) + f = open(output_file, 'w') + for bbox in bboxes: + x1 = bbox[0] / image_w + y1 = bbox[1] / image_h + x2 = bbox[2] / image_w + y2 = bbox[3] / image_h + draw.line([(x1, y1), (x1, y2), (x2, y2), (x2, y1), (x1, y1)], fill='red', width=2) + str_tmp = str(int(x1)) + "," + str(int(y1)) + "," + str(int(x2)) + "," + str(int(y2)) + f.write(str_tmp) + f.write("\n") + f.close() + im.save(img) + percent = round(eval_iter / total * 100, 2) + eval_iter = eval_iter + 1 + print("Iter {} cost time {}".format(eval_iter, end - start)) + print(' %s [%d/%d]' % (str(percent) + '%', eval_iter, total), end='\r') + +if __name__ == '__main__': + ctpn_infer_test(args_opt.dataset_path, args_opt.checkpoint_path, img_dir=args_opt.image_path) diff --git a/model_zoo/official/cv/ctpn/scripts/eval_res.sh b/model_zoo/official/cv/ctpn/scripts/eval_res.sh new file mode 100644 index 0000000000..1e4221fead --- /dev/null +++ b/model_zoo/official/cv/ctpn/scripts/eval_res.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +for submit_file in "submit"*.zip +do + echo "eval result for ${submit_file}" + python script.py –g=gt.zip –s=${submit_file} –o=./ + echo -e ".\n" +done diff --git a/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh b/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh new file mode 100644 index 0000000000..98b5cee57d --- /dev/null +++ b/model_zoo/official/cv/ctpn/scripts/run_distribute_train_ascend.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 3 ] +then + echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [TASK_TYPE] [PRETRAINED_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 + +if [ ! -f $PATH1 ] +then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" +exit 1 +fi +TASK_TYPE=$2 +PATH2=$(get_real_path $3) +echo $PATH2 +if [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +for((i=0; i<${DEVICE_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env > env.log + python train.py --device_id=$i --rank_id=$i --run_distribute=True --device_num=$DEVICE_NUM --task_type=$TASK_TYPE --pre_trained=$PATH2 &> log & + cd .. +done diff --git a/model_zoo/official/cv/ctpn/scripts/run_eval_ascend.sh b/model_zoo/official/cv/ctpn/scripts/run_eval_ascend.sh new file mode 100644 index 0000000000..0a0d49ab02 --- /dev/null +++ b/model_zoo/official/cv/ctpn/scripts/run_eval_ascend.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 3 ] +then + echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +IMAGE_PATH=$(get_real_path $1) +DATASET_PATH=$(get_real_path $2) +CHECKPOINT_PATH=$(get_real_path $3) +echo $IMAGE_PATH +echo $DATASET_PATH +echo $CHECKPOINT_PATH + +if [ ! -d $IMAGE_PATH ] +then + echo "error: IMAGE_PATH=$PATH1 is not a path" +exit 1 +fi + +if [ ! -f $DATASET_PATH ] +then + echo "error: CHECKPOINT_PATH=$DATASET_PATH is not a path" +exit 1 +fi + +if [ ! -d $CHECKPOINT_PATH ] +then + echo "error: CHECKPOINT_PATH=$CHECKPOINT_PATH is not a directory" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export RANK_SIZE=$DEVICE_NUM +export DEVICE_ID=0 +export RANK_ID=0 +for file in "${CHECKPOINT_PATH}"/*.ckpt +do + if [ -d "eval" ]; + then + rm -rf ./eval + fi + mkdir ./eval + cp ../*.py ./eval + cp *.sh ./eval + cp -r ../src ./eval + cd ./eval + env > env.log + CHECKPOINT_FILE_PATH=$file + echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" + python eval.py --device_id=$DEVICE_ID --image_path=$IMAGE_PATH --dataset_path=$DATASET_PATH --checkpoint_path=$CHECKPOINT_FILE_PATH &> log + echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}" + cd ./submit + file_base_name=$(basename $file) + zip -r ../../submit_${file_base_name%.*}.zip *.txt + cd ../../ +done diff --git a/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh new file mode 100644 index 0000000000..1649590a04 --- /dev/null +++ b/model_zoo/official/cv/ctpn/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +if [ $# -ne 2 ] +then + echo "Usage: sh run_distribute_train_ascend.sh [TASK_TYPE] [PRETRAINED_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +TASK_TYPE=$1 +PRETRAINED_PATH=$(get_real_path $2) +echo $PRETRAINED_PATH +if [ ! -f $PRETRAINED_PATH ] +then + echo "error: PRETRAINED_PATH=$PRETRAINED_PATH is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +rm -rf ./train +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log +python train.py --device_id=$DEVICE_ID --task_type=$TASK_TYPE --pre_trained=$PRETRAINED_PATH &> log & +cd .. diff --git a/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxDecode.py b/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxDecode.py new file mode 100644 index 0000000000..de277b2640 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxDecode.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn as nn +from mindspore.ops import operations as P + +class BoundingBoxDecode(nn.Cell): + """ + BoundintBox Decoder. + + Returns: + pred_box(Tensor): decoder bounding boxes. + """ + def __init__(self): + super(BoundingBoxDecode, self).__init__() + self.split = P.Split(axis=1, output_num=4) + self.ones = 1.0 + self.half = 0.5 + self.log = P.Log() + self.exp = P.Exp() + self.concat = P.Concat(axis=1) + + def construct(self, bboxes, deltas): + """ + boxes(Tensor): boundingbox. + deltas(Tensor): delta between boundingboxs and anchors. + """ + x1, y1, x2, y2 = self.split(bboxes) + width = x2 - x1 + self.ones + height = y2 - y1 + self.ones + ctr_x = x1 + self.half * width + ctr_y = y1 + self.half * height + _, dy, _, dh = self.split(deltas) + pred_ctr_x = ctr_x + pred_ctr_y = dy * height + ctr_y + pred_w = width + pred_h = self.exp(dh) * height + + x1 = pred_ctr_x - self.half * pred_w + y1 = pred_ctr_y - self.half * pred_h + x2 = pred_ctr_x + self.half * pred_w + y2 = pred_ctr_y + self.half * pred_h + pred_box = self.concat((x1, y1, x2, y2)) + return pred_box diff --git a/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxEncode.py b/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxEncode.py new file mode 100644 index 0000000000..33b8e14dcd --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/BoundingBoxEncode.py @@ -0,0 +1,55 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn as nn +from mindspore.ops import operations as P + +class BoundingBoxEncode(nn.Cell): + """ + BoundintBox Decoder. + + Returns: + pred_box(Tensor): decoder bounding boxes. + """ + def __init__(self): + super(BoundingBoxEncode, self).__init__() + self.split = P.Split(axis=1, output_num=4) + self.ones = 1.0 + self.half = 0.5 + self.log = P.Log() + self.concat = P.Concat(axis=1) + def construct(self, anchor_box, gt_box): + """ + boxes(Tensor): boundingbox. + deltas(Tensor): delta between boundingboxs and anchors. + """ + x1, y1, x2, y2 = self.split(anchor_box) + width = x2 - x1 + self.ones + height = y2 - y1 + self.ones + ctr_x = x1 + self.half * width + ctr_y = y1 + self.half * height + gt_x1, gt_y1, gt_x2, gt_y2 = self.split(gt_box) + gt_width = gt_x2 - gt_x1 + self.ones + gt_height = gt_y2 - gt_y1 + self.ones + ctr_gt_x = gt_x1 + self.half * gt_width + ctr_gt_y = gt_y1 + self.half * gt_height + + target_dx = (ctr_gt_x - ctr_x) / width + target_dy = (ctr_gt_y - ctr_y) / height + dw = gt_width / width + dh = gt_height / height + target_dw = self.log(dw) + target_dh = self.log(dh) + deltas = self.concat((target_dx, target_dy, target_dw, target_dh)) + return deltas diff --git a/model_zoo/official/cv/ctpn/src/CTPN/__init__.py b/model_zoo/official/cv/ctpn/src/CTPN/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py b/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py new file mode 100644 index 0000000000..c5c26c28eb --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FasterRcnn anchor generator.""" +import numpy as np +class AnchorGenerator(): + """Anchor generator for FasterRcnn.""" + def __init__(self, config): + """Anchor generator init method.""" + self.base_size = config.anchor_base + self.num_anchor = config.num_anchors + self.anchor_height = config.anchor_height + self.anchor_width = config.anchor_width + self.size = self.gen_anchor_size() + self.base_anchors = self.gen_base_anchors() + + def gen_base_anchors(self): + """Generate a single anchor.""" + base_anchor = np.array([0, 0, self.base_size - 1, self.base_size - 1], np.int32) + anchors = np.zeros((len(self.size), 4), np.int32) + index = 0 + for h, w in self.size: + anchors[index] = self.scale_anchor(base_anchor, h, w) + index += 1 + return anchors + + def gen_anchor_size(self): + """Generate a list of anchor size""" + size = [] + for width in self.anchor_width: + for height in self.anchor_height: + size.append((height, width)) + return size + + def scale_anchor(self, anchor, h, w): + x_ctr = (anchor[0] + anchor[2]) * 0.5 + y_ctr = (anchor[1] + anchor[3]) * 0.5 + scaled_anchor = anchor.copy() + scaled_anchor[0] = x_ctr - w / 2 # xmin + scaled_anchor[2] = x_ctr + w / 2 # xmax + scaled_anchor[1] = y_ctr - h / 2 # ymin + scaled_anchor[3] = y_ctr + h / 2 # ymax + return scaled_anchor + + def _meshgrid(self, x, y): + """Generate grid.""" + xx = np.repeat(x.reshape(1, len(x)), len(y), axis=0).reshape(-1) + yy = np.repeat(y, len(x)) + return xx, yy + + def grid_anchors(self, featmap_size, stride=16): + """Generate anchor list.""" + base_anchors = self.base_anchors + feat_h, feat_w = featmap_size + shift_x = np.arange(0, feat_w) * stride + shift_y = np.arange(0, feat_h) * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + shifts = np.stack([shift_xx, shift_yy, shift_xx, shift_yy], axis=-1) + shifts = shifts.astype(base_anchors.dtype) + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.reshape(-1, 4) + return all_anchors diff --git a/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py b/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py new file mode 100644 index 0000000000..18d2aa1130 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py @@ -0,0 +1,152 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FasterRcnn positive and negative sample screening for RPN.""" + +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype +from src.CTPN.BoundingBoxEncode import BoundingBoxEncode + + +class BboxAssignSample(nn.Cell): + """ + Bbox assigner and sampler definition. + + Args: + config (dict): Config. + batch_size (int): Batchsize. + num_bboxes (int): The anchor nums. + add_gt_as_proposals (bool): add gt bboxes as proposals flag. + + Returns: + Tensor, output tensor. + bbox_targets: bbox location, (batch_size, num_bboxes, 4) + bbox_weights: bbox weights, (batch_size, num_bboxes, 1) + labels: label for every bboxes, (batch_size, num_bboxes, 1) + label_weights: label weight for every bboxes, (batch_size, num_bboxes, 1) + + Examples: + BboxAssignSample(config, 2, 1024, True) + """ + + def __init__(self, config, batch_size, num_bboxes, add_gt_as_proposals): + super(BboxAssignSample, self).__init__() + cfg = config + self.batch_size = batch_size + + self.neg_iou_thr = Tensor(cfg.neg_iou_thr, mstype.float16) + self.pos_iou_thr = Tensor(cfg.pos_iou_thr, mstype.float16) + self.min_pos_iou = Tensor(cfg.min_pos_iou, mstype.float16) + self.zero_thr = Tensor(0.0, mstype.float16) + + self.num_bboxes = num_bboxes + self.num_gts = cfg.num_gts + self.num_expected_pos = cfg.num_expected_pos + self.num_expected_neg = cfg.num_expected_neg + self.add_gt_as_proposals = add_gt_as_proposals + + if self.add_gt_as_proposals: + self.label_inds = Tensor(np.arange(1, self.num_gts + 1)) + + self.concat = P.Concat(axis=0) + self.max_gt = P.ArgMaxWithValue(axis=0) + self.max_anchor = P.ArgMaxWithValue(axis=1) + self.sum_inds = P.ReduceSum() + self.iou = P.IOU() + self.greaterequal = P.GreaterEqual() + self.greater = P.Greater() + self.select = P.Select() + self.gatherND = P.GatherNd() + self.squeeze = P.Squeeze() + self.cast = P.Cast() + self.logicaland = P.LogicalAnd() + self.less = P.Less() + self.random_choice_with_mask_pos = P.RandomChoiceWithMask(self.num_expected_pos) + self.random_choice_with_mask_neg = P.RandomChoiceWithMask(self.num_expected_neg) + self.reshape = P.Reshape() + self.equal = P.Equal() + self.bounding_box_encode = BoundingBoxEncode() + self.scatterNdUpdate = P.ScatterNdUpdate() + self.scatterNd = P.ScatterNd() + self.logicalnot = P.LogicalNot() + self.tile = P.Tile() + self.zeros_like = P.ZerosLike() + + self.assigned_gt_inds = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_zeros = Tensor(np.array(np.zeros(num_bboxes), dtype=np.int32)) + self.assigned_gt_ones = Tensor(np.array(np.ones(num_bboxes), dtype=np.int32)) + self.assigned_gt_ignores = Tensor(np.array(-1 * np.ones(num_bboxes), dtype=np.int32)) + self.assigned_pos_ones = Tensor(np.array(np.ones(self.num_expected_pos), dtype=np.int32)) + + self.check_neg_mask = Tensor(np.array(np.ones(self.num_expected_neg - self.num_expected_pos), dtype=np.bool)) + self.range_pos_size = Tensor(np.arange(self.num_expected_pos).astype(np.float16)) + self.check_gt_one = Tensor(np.array(-1 * np.ones((self.num_gts, 4)), dtype=np.float16)) + self.check_anchor_two = Tensor(np.array(-2 * np.ones((self.num_bboxes, 4)), dtype=np.float16)) + self.print = P.Print() + + + def construct(self, gt_bboxes_i, gt_labels_i, valid_mask, bboxes, gt_valids): + gt_bboxes_i = self.select(self.cast(self.tile(self.reshape(self.cast(gt_valids, mstype.int32), \ + (self.num_gts, 1)), (1, 4)), mstype.bool_), gt_bboxes_i, self.check_gt_one) + bboxes = self.select(self.cast(self.tile(self.reshape(self.cast(valid_mask, mstype.int32), \ + (self.num_bboxes, 1)), (1, 4)), mstype.bool_), bboxes, self.check_anchor_two) + overlaps = self.iou(bboxes, gt_bboxes_i) + max_overlaps_w_gt_index, max_overlaps_w_gt = self.max_gt(overlaps) + _, max_overlaps_w_ac = self.max_anchor(overlaps) + neg_sample_iou_mask = self.logicaland(self.greaterequal(max_overlaps_w_gt, self.zero_thr), \ + self.less(max_overlaps_w_gt, self.neg_iou_thr)) + assigned_gt_inds2 = self.select(neg_sample_iou_mask, self.assigned_gt_zeros, self.assigned_gt_inds) + pos_sample_iou_mask = self.greaterequal(max_overlaps_w_gt, self.pos_iou_thr) + assigned_gt_inds3 = self.select(pos_sample_iou_mask, \ + max_overlaps_w_gt_index + self.assigned_gt_ones, assigned_gt_inds2) + assigned_gt_inds4 = assigned_gt_inds3 + for j in range(self.num_gts): + max_overlaps_w_ac_j = max_overlaps_w_ac[j:j+1:1] + overlaps_w_gt_j = self.squeeze(overlaps[j:j+1:1, ::]) + + pos_mask_j = self.logicaland(self.greaterequal(max_overlaps_w_ac_j, self.min_pos_iou), \ + self.equal(overlaps_w_gt_j, max_overlaps_w_ac_j)) + assigned_gt_inds4 = self.select(pos_mask_j, self.assigned_gt_ones + j, assigned_gt_inds4) + assigned_gt_inds5 = self.select(valid_mask, assigned_gt_inds4, self.assigned_gt_ignores) + pos_index, valid_pos_index = self.random_choice_with_mask_pos(self.greater(assigned_gt_inds5, 0)) + pos_check_valid = self.cast(self.greater(assigned_gt_inds5, 0), mstype.float16) + pos_check_valid = self.sum_inds(pos_check_valid, -1) + valid_pos_index = self.less(self.range_pos_size, pos_check_valid) + pos_index = pos_index * self.reshape(self.cast(valid_pos_index, mstype.int32), (self.num_expected_pos, 1)) + pos_assigned_gt_index = self.gatherND(assigned_gt_inds5, pos_index) - self.assigned_pos_ones + pos_assigned_gt_index = pos_assigned_gt_index * self.cast(valid_pos_index, mstype.int32) + pos_assigned_gt_index = self.reshape(pos_assigned_gt_index, (self.num_expected_pos, 1)) + neg_index, valid_neg_index = self.random_choice_with_mask_neg(self.equal(assigned_gt_inds5, 0)) + + num_pos = self.cast(self.logicalnot(valid_pos_index), mstype.float16) + num_pos = self.sum_inds(num_pos, -1) + unvalid_pos_index = self.less(self.range_pos_size, num_pos) + valid_neg_index = self.logicaland(self.concat((self.check_neg_mask, unvalid_pos_index)), valid_neg_index) + pos_bboxes_ = self.gatherND(bboxes, pos_index) + pos_gt_bboxes_ = self.gatherND(gt_bboxes_i, pos_assigned_gt_index) + pos_gt_labels = self.gatherND(gt_labels_i, pos_assigned_gt_index) + pos_bbox_targets_ = self.bounding_box_encode(pos_bboxes_, pos_gt_bboxes_) + valid_pos_index = self.cast(valid_pos_index, mstype.int32) + valid_neg_index = self.cast(valid_neg_index, mstype.int32) + bbox_targets_total = self.scatterNd(pos_index, pos_bbox_targets_, (self.num_bboxes, 4)) + bbox_weights_total = self.scatterNd(pos_index, valid_pos_index, (self.num_bboxes,)) + labels_total = self.scatterNd(pos_index, pos_gt_labels, (self.num_bboxes,)) + total_index = self.concat((pos_index, neg_index)) + total_valid_index = self.concat((valid_pos_index, valid_neg_index)) + label_weights_total = self.scatterNd(total_index, total_valid_index, (self.num_bboxes,)) + return bbox_targets_total, self.cast(bbox_weights_total, mstype.bool_), \ + labels_total, self.cast(label_weights_total, mstype.bool_) diff --git a/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py b/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py new file mode 100644 index 0000000000..34b187fbf7 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py @@ -0,0 +1,190 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FasterRcnn proposal generator.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore import Tensor +from src.CTPN.BoundingBoxDecode import BoundingBoxDecode + +class Proposal(nn.Cell): + """ + Proposal subnet. + + Args: + config (dict): Config. + batch_size (int): Batchsize. + num_classes (int) - Class number. + use_sigmoid_cls (bool) - Select sigmoid or softmax function. + target_means (tuple) - Means for encode function. Default: (.0, .0, .0, .0). + target_stds (tuple) - Stds for encode function. Default: (1.0, 1.0, 1.0, 1.0). + + Returns: + Tuple, tuple of output tensor,(proposal, mask). + + Examples: + Proposal(config = config, batch_size = 1, num_classes = 81, use_sigmoid_cls = True, \ + target_means=(.0, .0, .0, .0), target_stds=(1.0, 1.0, 1.0, 1.0)) + """ + def __init__(self, + config, + batch_size, + num_classes, + use_sigmoid_cls, + target_means=(.0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0) + ): + super(Proposal, self).__init__() + cfg = config + self.batch_size = batch_size + self.num_classes = num_classes + self.target_means = target_means + self.target_stds = target_stds + self.use_sigmoid_cls = config.use_sigmoid_cls + + if self.use_sigmoid_cls: + self.cls_out_channels = 1 + self.activation = P.Sigmoid() + self.reshape_shape = (-1, 1) + else: + self.cls_out_channels = num_classes + self.activation = P.Softmax(axis=1) + self.reshape_shape = (-1, 2) + + if self.cls_out_channels <= 0: + raise ValueError('num_classes={} is too small'.format(num_classes)) + + self.num_pre = cfg.rpn_proposal_nms_pre + self.min_box_size = cfg.rpn_proposal_min_bbox_size + self.nms_thr = cfg.rpn_proposal_nms_thr + self.nms_post = cfg.rpn_proposal_nms_post + self.nms_across_levels = cfg.rpn_proposal_nms_across_levels + self.max_num = cfg.rpn_proposal_max_num + + # Op Define + self.squeeze = P.Squeeze() + self.reshape = P.Reshape() + self.cast = P.Cast() + + self.feature_shapes = cfg.feature_shapes + + self.transpose_shape = (1, 2, 0) + + self.decode = BoundingBoxDecode() + + self.nms = P.NMSWithMask(self.nms_thr) + self.concat_axis0 = P.Concat(axis=0) + self.concat_axis1 = P.Concat(axis=1) + self.split = P.Split(axis=1, output_num=5) + self.min = P.Minimum() + self.gatherND = P.GatherNd() + self.slice = P.Slice() + self.select = P.Select() + self.greater = P.Greater() + self.transpose = P.Transpose() + self.tile = P.Tile() + self.set_train_local(config, training=True) + + self.multi_10 = Tensor(10.0, mstype.float16) + + def set_train_local(self, config, training=False): + """Set training flag.""" + self.training_local = training + cfg = config + self.topK_stage1 = () + self.topK_shape = () + total_max_topk_input = 0 + if not self.training_local: + self.num_pre = cfg.rpn_nms_pre + self.min_box_size = cfg.rpn_min_bbox_min_size + self.nms_thr = cfg.rpn_nms_thr + self.nms_post = cfg.rpn_nms_post + self.max_num = cfg.rpn_max_num + k_num = self.num_pre + total_max_topk_input = k_num + self.topK_stage1 = k_num + self.topK_shape = (k_num, 1) + + self.topKv2 = P.TopK(sorted=True) + self.topK_shape_stage2 = (self.max_num, 1) + self.min_float_num = -65536.0 + self.topK_mask = Tensor(self.min_float_num * np.ones(total_max_topk_input, np.float16)) + self.shape = P.Shape() + self.print = P.Print() + + def construct(self, rpn_cls_score_total, rpn_bbox_pred_total, anchor_list): + proposals_tuple = () + masks_tuple = () + for img_id in range(self.batch_size): + rpn_cls_score_i = self.squeeze(rpn_cls_score_total[img_id:img_id+1:1, ::, ::, ::]) + rpn_bbox_pred_i = self.squeeze(rpn_bbox_pred_total[img_id:img_id+1:1, ::, ::, ::]) + proposals, masks = self.get_bboxes_single(rpn_cls_score_i, rpn_bbox_pred_i, anchor_list) + proposals_tuple += (proposals,) + masks_tuple += (masks,) + return proposals_tuple, masks_tuple + + def get_bboxes_single(self, cls_scores, bbox_preds, mlvl_anchors): + """Get proposal boundingbox.""" + mlvl_proposals = () + mlvl_mask = () + + rpn_cls_score = self.transpose(cls_scores, self.transpose_shape) + rpn_bbox_pred = self.transpose(bbox_preds, self.transpose_shape) + anchors = mlvl_anchors + + # (H, W, A*2) + rpn_cls_score_shape = self.shape(rpn_cls_score) + rpn_cls_score = self.reshape(rpn_cls_score, (rpn_cls_score_shape[0], \ + rpn_cls_score_shape[1], -1, self.cls_out_channels)) + rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape) + rpn_cls_score = self.activation(rpn_cls_score) + if self.use_sigmoid_cls: + rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score), mstype.float16) + else: + rpn_cls_score_process = self.cast(self.squeeze(rpn_cls_score[::, 1]), mstype.float16) + + rpn_bbox_pred_process = self.cast(self.reshape(rpn_bbox_pred, (-1, 4)), mstype.float16) + + scores_sorted, topk_inds = self.topKv2(rpn_cls_score_process, self.num_pre) + + topk_inds = self.reshape(topk_inds, self.topK_shape) + + bboxes_sorted = self.gatherND(rpn_bbox_pred_process, topk_inds) + anchors_sorted = self.cast(self.gatherND(anchors, topk_inds), mstype.float16) + + proposals_decode = self.decode(anchors_sorted, bboxes_sorted) + + proposals_decode = self.concat_axis1((proposals_decode, self.reshape(scores_sorted, self.topK_shape))) + proposals, _, mask_valid = self.nms(proposals_decode) + + mlvl_proposals = mlvl_proposals + (proposals,) + mlvl_mask = mlvl_mask + (mask_valid,) + + proposals = self.concat_axis0(mlvl_proposals) + masks = self.concat_axis0(mlvl_mask) + + _, _, _, _, scores = self.split(proposals) + scores = self.squeeze(scores) + topk_mask = self.cast(self.topK_mask, mstype.float16) + scores_using = self.select(masks, scores, topk_mask) + + _, topk_inds = self.topKv2(scores_using, self.max_num) + + topk_inds = self.reshape(topk_inds, self.topK_shape_stage2) + proposals = self.gatherND(proposals, topk_inds) + masks = self.gatherND(masks, topk_inds) + return proposals, masks diff --git a/model_zoo/official/cv/ctpn/src/CTPN/rpn.py b/model_zoo/official/cv/ctpn/src/CTPN/rpn.py new file mode 100644 index 0000000000..46826c66fe --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/rpn.py @@ -0,0 +1,228 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""RPN for fasterRCNN""" +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.ops import functional as F +from src.CTPN.bbox_assign_sample import BboxAssignSample + +class RpnRegClsBlock(nn.Cell): + """ + Rpn reg cls block for rpn layer + + Args: + config(EasyDict) - Network construction config. + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + + Returns: + Tensor, output tensor. + """ + + def __init__(self, + config, + in_channels, + feat_channels, + num_anchors, + cls_out_channels): + super(RpnRegClsBlock, self).__init__() + self.shape = P.Shape() + self.reshape = P.Reshape() + self.shape = (-1, 2*config.hidden_size) + self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16) + self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16) + self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16) + self.shape1 = (config.num_step, config.rnn_batch_size, -1) + self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step) + self.transpose = P.Transpose() + self.print = P.Print() + self.dropout = nn.Dropout(0.8) + + def construct(self, x): + x = self.reshape(x, self.shape) + x = self.lstm_fc(x) + x1 = self.rpn_cls(x) + x1 = self.reshape(x1, self.shape1) + x1 = self.transpose(x1, (2, 1, 0)) + x1 = self.reshape(x1, self.shape2) + x1 = self.transpose(x1, (1, 0, 2, 3)) + x2 = self.rpn_reg(x) + x2 = self.reshape(x2, self.shape1) + x2 = self.transpose(x2, (2, 1, 0)) + x2 = self.reshape(x2, self.shape2) + x2 = self.transpose(x2, (1, 0, 2, 3)) + return x1, x2 + +class RPN(nn.Cell): + """ + ROI proposal network.. + + Args: + config (dict) - Config. + batch_size (int) - Batchsize. + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + + Returns: + Tuple, tuple of output tensor. + + Examples: + RPN(config=config, batch_size=2, in_channels=256, feat_channels=1024, + num_anchors=3, cls_out_channels=512) + """ + def __init__(self, + config, + batch_size, + in_channels, + feat_channels, + num_anchors, + cls_out_channels): + super(RPN, self).__init__() + cfg_rpn = config + self.cfg = config + self.num_bboxes = cfg_rpn.num_bboxes + self.feature_anchor_shape = cfg_rpn.feature_shapes + self.feature_anchor_shape = self.feature_anchor_shape[0] * \ + self.feature_anchor_shape[1] * num_anchors * batch_size + self.num_anchors = num_anchors + self.batch_size = batch_size + self.test_batch_size = cfg_rpn.test_batch_size + self.num_layers = 1 + self.real_ratio = Tensor(np.ones((1, 1)).astype(np.float16)) + self.use_sigmoid_cls = config.use_sigmoid_cls + if config.use_sigmoid_cls: + self.reshape_shape_cls = (-1,) + self.loss_cls = P.SigmoidCrossEntropyWithLogits() + cls_out_channels = 1 + else: + self.reshape_shape_cls = (-1, cls_out_channels) + self.loss_cls = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="none") + self.rpn_convs_list = self._make_rpn_layer(self.num_layers, in_channels, feat_channels,\ + num_anchors, cls_out_channels) + + self.transpose = P.Transpose() + self.reshape = P.Reshape() + self.concat = P.Concat(axis=0) + self.fill = P.Fill() + self.placeh1 = Tensor(np.ones((1,)).astype(np.float16)) + + self.trans_shape = (0, 2, 3, 1) + + self.reshape_shape_reg = (-1, 4) + self.softmax = nn.Softmax() + self.rpn_loss_reg_weight = Tensor(np.array(cfg_rpn.rpn_loss_reg_weight).astype(np.float16)) + self.rpn_loss_cls_weight = Tensor(np.array(cfg_rpn.rpn_loss_cls_weight).astype(np.float16)) + self.num_expected_total = Tensor(np.array(cfg_rpn.num_expected_neg * self.batch_size).astype(np.float16)) + self.num_bboxes = cfg_rpn.num_bboxes + self.get_targets = BboxAssignSample(cfg_rpn, self.batch_size, self.num_bboxes, False) + self.CheckValid = P.CheckValid() + self.sum_loss = P.ReduceSum() + self.loss_bbox = P.SmoothL1Loss(beta=1.0/9.0) + self.squeeze = P.Squeeze() + self.cast = P.Cast() + self.tile = P.Tile() + self.zeros_like = P.ZerosLike() + self.loss = Tensor(np.zeros((1,)).astype(np.float16)) + self.clsloss = Tensor(np.zeros((1,)).astype(np.float16)) + self.regloss = Tensor(np.zeros((1,)).astype(np.float16)) + self.print = P.Print() + + def _make_rpn_layer(self, num_layers, in_channels, feat_channels, num_anchors, cls_out_channels): + """ + make rpn layer for rpn proposal network + + Args: + num_layers (int) - layer num. + in_channels (int) - Input channels of shared convolution. + feat_channels (int) - Output channels of shared convolution. + num_anchors (int) - The anchor number. + cls_out_channels (int) - Output channels of classification convolution. + + Returns: + List, list of RpnRegClsBlock cells. + """ + rpn_layer = RpnRegClsBlock(self.cfg, in_channels, feat_channels, num_anchors, cls_out_channels) + return rpn_layer + + def construct(self, inputs, img_metas, anchor_list, gt_bboxes, gt_labels, gt_valids): + ''' + inputs(Tensor): Inputs tensor from lstm. + img_metas(Tensor): Image shape. + anchor_list(Tensor): Total anchor list. + gt_labels(Tensor): Ground truth labels. + gt_valids(Tensor): Whether ground truth is valid. + ''' + rpn_cls_score_ori, rpn_bbox_pred_ori = self.rpn_convs_list(inputs) + rpn_cls_score = self.transpose(rpn_cls_score_ori, self.trans_shape) + rpn_cls_score = self.reshape(rpn_cls_score, self.reshape_shape_cls) + rpn_bbox_pred = self.transpose(rpn_bbox_pred_ori, self.trans_shape) + rpn_bbox_pred = self.reshape(rpn_bbox_pred, self.reshape_shape_reg) + output = () + bbox_targets = () + bbox_weights = () + labels = () + label_weights = () + if self.training: + for i in range(self.batch_size): + valid_flag_list = self.cast(self.CheckValid(anchor_list, self.squeeze(img_metas[i:i + 1:1, ::])),\ + mstype.int32) + gt_bboxes_i = self.squeeze(gt_bboxes[i:i + 1:1, ::]) + gt_labels_i = self.squeeze(gt_labels[i:i + 1:1, ::]) + gt_valids_i = self.squeeze(gt_valids[i:i + 1:1, ::]) + bbox_target, bbox_weight, label, label_weight = self.get_targets(gt_bboxes_i, + gt_labels_i, + self.cast(valid_flag_list, + mstype.bool_), + anchor_list, gt_valids_i) + bbox_weight = self.cast(bbox_weight, mstype.float16) + label_weight = self.cast(label_weight, mstype.float16) + bbox_targets += (bbox_target,) + bbox_weights += (bbox_weight,) + labels += (label,) + label_weights += (label_weight,) + bbox_target_with_batchsize = self.concat(bbox_targets) + bbox_weight_with_batchsize = self.concat(bbox_weights) + label_with_batchsize = self.concat(labels) + label_weight_with_batchsize = self.concat(label_weights) + + bbox_target_ = F.stop_gradient(bbox_target_with_batchsize) + bbox_weight_ = F.stop_gradient(bbox_weight_with_batchsize) + label_ = F.stop_gradient(label_with_batchsize) + label_weight_ = F.stop_gradient(label_weight_with_batchsize) + rpn_cls_score = self.cast(rpn_cls_score, mstype.float32) + if self.use_sigmoid_cls: + label_ = self.cast(label_, mstype.float32) + loss_cls = self.loss_cls(rpn_cls_score, label_) + loss_cls = loss_cls * label_weight_ + loss_cls = self.sum_loss(loss_cls, (0,)) / self.num_expected_total + rpn_bbox_pred = self.cast(rpn_bbox_pred, mstype.float32) + bbox_target_ = self.cast(bbox_target_, mstype.float32) + loss_reg = self.loss_bbox(rpn_bbox_pred, bbox_target_) + bbox_weight_ = self.tile(self.reshape(bbox_weight_, (self.feature_anchor_shape, 1)), (1, 4)) + loss_reg = loss_reg * bbox_weight_ + loss_reg = self.sum_loss(loss_reg, (1,)) + loss_reg = self.sum_loss(loss_reg, (0,)) / self.num_expected_total + loss_total = self.rpn_loss_cls_weight * loss_cls + self.rpn_loss_reg_weight * loss_reg + output = (loss_total, rpn_cls_score_ori, rpn_bbox_pred_ori, loss_cls, loss_reg) + else: + output = (self.placeh1, rpn_cls_score_ori, rpn_bbox_pred_ori, self.placeh1, self.placeh1) + return output diff --git a/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py b/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py new file mode 100644 index 0000000000..0d06e68b95 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/CTPN/vgg16.py @@ -0,0 +1,177 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype + +def _weight_variable(shape, factor=0.01): + ''''weight initialize''' + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + +def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False): + """Batchnorm2D wrapper.""" + gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) + beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) + moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) + moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) + + return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, + beta_init=beta_init, moving_mean_init=moving_mean_init, + moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) + +def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True): + """Conv2D wrapper.""" + weights = 'ones' + layers = [] + conv = nn.Conv2d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, + pad_mode=pad_mode, weight_init=weights, has_bias=False) + if not weights_update: + conv.weight.requires_grad = False + layers += [conv] + layers += [_BatchNorm2dInit(out_channels)] + return nn.SequentialCell(layers) + + +def _fc(in_channels, out_channels): + '''full connection layer''' + weight = _weight_variable((out_channels, in_channels)) + bias = _weight_variable((out_channels,)) + return nn.Dense(in_channels, out_channels, weight, bias) + + +class VGG16FeatureExtraction(nn.Cell): + def __init__(self, weights_update=False): + """ + VGG16 feature extraction + + Args: + weights_updata(bool): whether update weights for top two layers, default is False. + """ + super(VGG16FeatureExtraction, self).__init__() + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") + self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) + + self.conv1_1 = _conv(in_channels=3, out_channels=64, kernel_size=3,\ + padding=1, weights_update=weights_update) + self.conv1_2 = _conv(in_channels=64, out_channels=64, kernel_size=3,\ + padding=1, weights_update=weights_update) + + self.conv2_1 = _conv(in_channels=64, out_channels=128, kernel_size=3,\ + padding=1, weights_update=weights_update) + self.conv2_2 = _conv(in_channels=128, out_channels=128, kernel_size=3,\ + padding=1, weights_update=weights_update) + + self.conv3_1 = _conv(in_channels=128, out_channels=256, kernel_size=3, padding=1) + self.conv3_2 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1) + self.conv3_3 = _conv(in_channels=256, out_channels=256, kernel_size=3, padding=1) + + self.conv4_1 = _conv(in_channels=256, out_channels=512, kernel_size=3, padding=1) + self.conv4_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) + self.conv4_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) + + self.conv5_1 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) + self.conv5_2 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) + self.conv5_3 = _conv(in_channels=512, out_channels=512, kernel_size=3, padding=1) + self.cast = P.Cast() + + def construct(self, x): + """ + :param x: shape=(B, 3, 224, 224) + :return: + """ + x = self.cast(x, mstype.float32) + x = self.conv1_1(x) + x = self.relu(x) + x = self.conv1_2(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.conv2_1(x) + x = self.relu(x) + x = self.conv2_2(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.conv3_1(x) + x = self.relu(x) + x = self.conv3_2(x) + x = self.relu(x) + x = self.conv3_3(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.conv4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.relu(x) + x = self.conv4_3(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.conv5_1(x) + x = self.relu(x) + x = self.conv5_2(x) + x = self.relu(x) + x = self.conv5_3(x) + x = self.relu(x) + return x + +class VGG16Classfier(nn.Cell): + def __init__(self): + """VGG16 classfier structure""" + super(VGG16Classfier, self).__init__() + self.flatten = P.Flatten() + self.relu = nn.ReLU() + self.fc1 = _fc(in_channels=7*7*512, out_channels=4096) + self.fc2 = _fc(in_channels=4096, out_channels=4096) + self.batch_size = 32 + self.reshape = P.Reshape() + + def construct(self, x): + """ + :param x: shape=(B, 512, 7, 7) + :return: + """ + x = self.reshape(x, (self.batch_size, 7*7*512)) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + return x + +class VGG16(nn.Cell): + def __init__(self): + """VGG16 construct for training backbone""" + super(VGG16, self).__init__() + self.feature_extraction = VGG16FeatureExtraction(weights_update=True) + self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.classifier = VGG16Classfier() + self.fc3 = _fc(in_channels=4096, out_channels=1000) + + def construct(self, x): + """ + :param x: shape=(B, 3, 224, 224) + :return: logits, shape=(B, 1000) + """ + feature_maps = self.feature_extraction(x) + x = self.max_pool(feature_maps) + x = self.classifier(x) + x = self.fc3(x) + return x diff --git a/model_zoo/official/cv/ctpn/src/config.py b/model_zoo/official/cv/ctpn/src/config.py new file mode 100644 index 0000000000..de114b72f5 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/config.py @@ -0,0 +1,133 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Network parameters.""" +from easydict import EasyDict +pretrain_config = EasyDict({ + # LR + "base_lr": 0.0009, + "warmup_step": 30000, + "warmup_ratio": 1/3.0, + "total_epoch": 100, +}) +finetune_config = EasyDict({ + # LR + "base_lr": 0.0005, + "warmup_step": 300, + "warmup_ratio": 1/3.0, + "total_epoch": 50, +}) + +# use for low case number +config = EasyDict({ + "img_width": 960, + "img_height": 576, + "keep_ratio": False, + "flip_ratio": 0.0, + "photo_ratio": 0.0, + "expand_ratio": 1.0, + + # anchor + "feature_shapes": (36, 60), + "num_anchors": 14, + "anchor_base": 16, + "anchor_height": [2, 4, 7, 11, 16, 23, 33, 48, 68, 97, 139, 198, 283, 406], + "anchor_width": [16], + + # rpn + "rpn_in_channels": 256, + "rpn_feat_channels": 512, + "rpn_loss_cls_weight": 1.0, + "rpn_loss_reg_weight": 3.0, + "rpn_cls_out_channels": 2, + + # bbox_assign_sampler + "neg_iou_thr": 0.5, + "pos_iou_thr": 0.7, + "min_pos_iou": 0.001, + "num_bboxes": 30240, + "num_gts": 256, + "num_expected_neg": 512, + "num_expected_pos": 256, + + #proposal + "activate_num_classes": 2, + "use_sigmoid_cls": False, + + # train proposal + "rpn_proposal_nms_across_levels": False, + "rpn_proposal_nms_pre": 2000, + "rpn_proposal_nms_post": 1000, + "rpn_proposal_max_num": 1000, + "rpn_proposal_nms_thr": 0.7, + "rpn_proposal_min_bbox_size": 8, + + # rnn structure + "input_size": 512, + "num_step": 60, + "rnn_batch_size": 36, + "hidden_size": 128, + + # training + "warmup_mode": "linear", + "batch_size": 1, + "momentum": 0.9, + "save_checkpoint": True, + "save_checkpoint_epochs": 10, + "keep_checkpoint_max": 5, + "save_checkpoint_path": "./", + "use_dropout": False, + "loss_scale": 1, + "weight_decay": 1e-4, + + # test proposal + "rpn_nms_pre": 2000, + "rpn_nms_post": 1000, + "rpn_max_num": 1000, + "rpn_nms_thr": 0.7, + "rpn_min_bbox_min_size": 8, + "test_iou_thr": 0.7, + "test_max_per_img": 100, + "test_batch_size": 1, + "use_python_proposal": False, + + # text proposal connection + "max_horizontal_gap": 60, + "text_proposals_min_scores": 0.7, + "text_proposals_nms_thresh": 0.2, + "min_v_overlaps": 0.7, + "min_size_sim": 0.7, + "min_ratio": 0.5, + "line_min_score": 0.9, + "text_proposals_width": 16, + "min_num_proposals": 2, + + # create dataset + "coco_root": "", + "coco_train_data_type": "", + "cocotext_json": "", + "icdar11_train_path": [], + "icdar13_train_path": [], + "icdar15_train_path": [], + "icdar13_test_path": [], + "flick_train_path": [], + "svt_train_path": [], + "pretrain_dataset_path": "", + "finetune_dataset_path": "", + "test_dataset_path": "", + + # training dataset + "pretraining_dataset_file": "", + "finetune_dataset_file": "" +}) diff --git a/model_zoo/official/cv/ctpn/src/convert_icdar2015.py b/model_zoo/official/cv/ctpn/src/convert_icdar2015.py new file mode 100644 index 0000000000..f2504ba948 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/convert_icdar2015.py @@ -0,0 +1,61 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""convert icdar2015 dataset label""" +import os +import argparse +def init_args(): + parser = argparse.ArgumentParser('') + parser.add_argument('-s', '--src_label_path', type=str, default='./', + help='Directory containing icdar2015 train label') + parser.add_argument('-t', '--target_label_path', type=str, default='test.xml', + help='Directory where save the icdar2015 label after convert') + return parser.parse_args() + +def convert(): + args = init_args() + anno_file = os.listdir(args.src_label_path) + annos = {} + # read + for file in anno_file: + gt = open(os.path.join(args.src_label_path, file), 'r', encoding='UTF-8-sig').read().splitlines() + label_list = [] + label_name = os.path.basename(file) + for each_label in gt: + print(file) + spt = each_label.split(',') + print(spt) + if "###" in spt[8]: + continue + else: + x1 = min(int(spt[0]), int(spt[6])) + y1 = min(int(spt[1]), int(spt[3])) + x2 = max(int(spt[2]), int(spt[4])) + y2 = max(int(spt[5]), int(spt[7])) + label_list.append([x1, y1, x2, y2]) + annos[label_name] = label_list + # write + if not os.path.exists(args.target_label_path): + os.makedirs(args.target_label_path) + for label_file, pos in annos.items(): + tgt_anno_file = os.path.join(args.target_label_path, label_file) + f = open(tgt_anno_file, 'w', encoding='UTF-8-sig') + for tgt_label in pos: + str_pos = str(tgt_label[0]) + ',' + str(tgt_label[1]) + ',' + str(tgt_label[2]) + ',' + str(tgt_label[3]) + f.write(str_pos) + f.write("\n") + f.close() + +if __name__ == "__main__": + convert() diff --git a/model_zoo/official/cv/ctpn/src/convert_svt.py b/model_zoo/official/cv/ctpn/src/convert_svt.py new file mode 100644 index 0000000000..5a767ed63c --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/convert_svt.py @@ -0,0 +1,94 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""convert svt dataset label""" +import os +import argparse +from xml.etree import ElementTree as ET +import numpy as np + +def init_args(): + parser = argparse.ArgumentParser('') + parser.add_argument('-d', '--dataset_dir', type=str, default='./', + help='Directory containing images') + parser.add_argument('-x', '--xml_file', type=str, default='test.xml', + help='Directory where character dictionaries for the dataset were stored') + parser.add_argument('-o', '--location_dir', type=str, default='./location', + help='Directory where ord map dictionaries for the dataset were stored') + return parser.parse_args() + +def xml_to_dict(xml_file, save_file=False): + tree = ET.parse(xml_file) + root = tree.getroot() + imgs_labels = [] + + for ch in root: + im_label = {} + for ch01 in ch: + if ch01.tag in "address": + continue + elif ch01.tag in 'taggedRectangles': + # multiple children + rect_list = [] + for ch02 in ch01: + rect = {} + rect['location'] = ch02.attrib + rect['label'] = ch02[0].text + rect_list.append(rect) + im_label['rect'] = rect_list + else: + im_label[ch01.tag] = ch01.text + imgs_labels.append(im_label) + + if save_file: + np.save("annotation_train.npy", imgs_labels) + + return imgs_labels + +def convert(): + args = init_args() + if not os.path.exists(args.dataset_dir): + raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir)) + + if not os.path.exists(args.xml_file): + raise ValueError("xml_file :{} does not exist".format(args.xml_file)) + + if not os.path.exists(args.location_dir): + os.makedirs(args.location_dir) + + ims_labels_dict = xml_to_dict(args.xml_file, True) + num_images = len(ims_labels_dict) + print("Converting annotation, {} images in total ".format(num_images)) + for i in range(num_images): + img_label = ims_labels_dict[i] + image_name = img_label['imageName'] + rects = img_label['rect'] + print("processing image: {}".format(image_name)) + location_file_name = os.path.join(args.location_dir, os.path.basename(image_name).replace(".jpg", ".txt")) + f = open(location_file_name, 'w') + for j, rect in enumerate(rects): + rect = rects[j] + location = rect['location'] + h = int(location['height']) + w = int(location['width']) + x = int(location['x']) + y = int(location['y']) + pos = [x, y, x+w, y+h] + str_pos = str(pos[0]) + "," + str(pos[1]) + "," + str(pos[2]) + "," + str(pos[3]) + f.write(str_pos) + f.write("\n") + f.close() + +if __name__ == "__main__": + convert() diff --git a/model_zoo/official/cv/ctpn/src/create_dataset.py b/model_zoo/official/cv/ctpn/src/create_dataset.py new file mode 100644 index 0000000000..ef9a8faf2c --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/create_dataset.py @@ -0,0 +1,177 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from __future__ import division +import os +import numpy as np +from PIL import Image +from mindspore.mindrecord import FileWriter +from src.config import config + +def create_coco_label(): + """Create image label.""" + image_files = [] + image_anno_dict = {} + coco_root = config.coco_root + data_type = config.coco_train_data_type + from src.coco_text import COCO_Text + anno_json = config.cocotext_json + ct = COCO_Text(anno_json) + image_ids = ct.getImgIds(imgIds=ct.train, + catIds=[('legibility', 'legible')]) + for img_id in image_ids: + image_info = ct.loadImgs(img_id)[0] + file_name = image_info['file_name'][15:] + anno_ids = ct.getAnnIds(imgIds=img_id) + anno = ct.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + im = Image.open(image_path) + width, _ = im.size + for label in anno: + bbox = label["bbox"] + bbox_width = int(bbox[2]) + if 60 * bbox_width < width: + continue + x1, x2 = int(bbox[0]), int(bbox[0] + bbox[2]) + y1, y2 = int(bbox[1]), int(bbox[1] + bbox[3]) + annos.append([x1, y1, x2, y2] + [1]) + if annos: + image_anno_dict[image_path] = np.array(annos) + image_files.append(image_path) + return image_files, image_anno_dict + +def create_anno_dataset_label(train_img_dirs, train_txt_dirs): + image_files = [] + image_anno_dict = {} + # read + img_basenames = [] + for file in os.listdir(train_img_dirs): + # Filter git file. + if 'gif' not in file: + img_basenames.append(os.path.basename(file)) + img_names = [] + for item in img_basenames: + temp1, _ = os.path.splitext(item) + img_names.append((temp1, item)) + for img, img_basename in img_names: + image_path = train_img_dirs + '/' + img_basename + annos = [] + if len(img) == 6 and '_' not in img_basename: + gt = open(train_txt_dirs + '/' + img + '.txt').read().splitlines() + if img.isdigit() and int(img) > 1200: + continue + for img_each_label in gt: + spt = img_each_label.replace(',', '').split(' ') + if ' ' not in img_each_label: + spt = img_each_label.split(',') + annos.append([spt[0], spt[1], str(int(spt[0]) + int(spt[2])), str(int(spt[1]) + int(spt[3]))] + [1]) + if annos: + image_anno_dict[image_path] = np.array(annos) + image_files.append(image_path) + return image_files, image_anno_dict + +def create_icdar_svt_label(train_img_dir, train_txt_dir, prefix): + image_files = [] + image_anno_dict = {} + img_basenames = [] + for file_name in os.listdir(train_img_dir): + if 'gif' not in file_name: + img_basenames.append(os.path.basename(file_name)) + img_names = [] + for item in img_basenames: + temp1, _ = os.path.splitext(item) + img_names.append((temp1, item)) + for img, img_basename in img_names: + image_path = train_img_dir + '/' + img_basename + annos = [] + file_name = prefix + img + ".txt" + file_path = os.path.join(train_txt_dir, file_name) + gt = open(file_path, 'r', encoding='UTF-8-sig').read().splitlines() + if not gt: + continue + for img_each_label in gt: + spt = img_each_label.replace(',', '').split(' ') + if ' ' not in img_each_label: + spt = img_each_label.split(',') + annos.append([spt[0], spt[1], spt[2], spt[3]] + [1]) + if annos: + image_anno_dict[image_path] = np.array(annos) + image_files.append(image_path) + return image_files, image_anno_dict + +def create_train_dataset(dataset_type): + image_files = [] + image_anno_dict = {} + if dataset_type == "pretraining": + # pretrianing: coco, flick, icdar2013 train, icdar2015, svt + coco_image_files, coco_anno_dict = create_coco_label() + flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0], + config.flick_train_path[1]) + icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0], + config.icdar13_train_path[1], "gt_img_") + icdar15_image_files, icdar15_anno_dict = create_icdar_svt_label(config.icdar15_train_path[0], + config.icdar15_train_path[1], "gt_") + svt_image_files, svt_anno_dict = create_icdar_svt_label(config.svt_train_path[0], config.svt_train_path[1], "") + image_files = coco_image_files + flick_image_files + icdar13_image_files + icdar15_image_files + svt_image_files + image_anno_dict = {**coco_anno_dict, **flick_anno_dict, \ + **icdar13_anno_dict, **icdar15_anno_dict, **svt_anno_dict} + data_to_mindrecord_byte_image(image_files, image_anno_dict, config.pretrain_dataset_path, \ + prefix="ctpn_pretrain.mindrecord", file_num=8) + elif dataset_type == "finetune": + # finetune: icdar2011, icdar2013 train, flick + flick_image_files, flick_anno_dict = create_anno_dataset_label(config.flick_train_path[0], + config.flick_train_path[1]) + icdar11_image_files, icdar11_anno_dict = create_icdar_svt_label(config.icdar11_train_path[0], + config.icdar11_train_path[1], "gt_") + icdar13_image_files, icdar13_anno_dict = create_icdar_svt_label(config.icdar13_train_path[0], + config.icdar13_train_path[1], "gt_img_") + image_files = flick_image_files + icdar11_image_files + icdar13_image_files + image_anno_dict = {**flick_anno_dict, **icdar11_anno_dict, **icdar13_anno_dict} + data_to_mindrecord_byte_image(image_files, image_anno_dict, config.finetune_dataset_path, \ + prefix="ctpn_finetune.mindrecord", file_num=8) + elif dataset_type == "test": + # test: icdar2013 test + icdar_test_image_files, icdar_test_anno_dict = create_icdar_svt_label(config.icdar13_test_path[0],\ + config.icdar13_test_path[1], "") + image_files = icdar_test_image_files + image_anno_dict = icdar_test_anno_dict + data_to_mindrecord_byte_image(image_files, image_anno_dict, config.test_dataset_path, \ + prefix="ctpn_test.mindrecord", file_num=1) + else: + print("dataset_type should be pretraining, finetune, test") + +def data_to_mindrecord_byte_image(image_files, image_anno_dict, dst_dir, prefix="cptn_mlt.mindrecord", file_num=1): + """Create MindRecord file.""" + mindrecord_path = os.path.join(dst_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + + ctpn_json = { + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ctpn_json, "ctpn_json") + for image_name in image_files: + with open(image_name, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[image_name], dtype=np.int32) + print("img name is {}, anno is {}".format(image_name, annos)) + row = {"image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + +if __name__ == "__main__": + create_train_dataset("pretraining") + create_train_dataset("finetune") + create_train_dataset("test") diff --git a/model_zoo/official/cv/ctpn/src/ctpn.py b/model_zoo/official/cv/ctpn/src/ctpn.py new file mode 100644 index 0000000000..3656a62c99 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/ctpn.py @@ -0,0 +1,148 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CPTN network definition.""" + +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from src.CTPN.rpn import RPN +from src.CTPN.anchor_generator import AnchorGenerator +from src.CTPN.proposal_generator import Proposal +from src.CTPN.vgg16 import VGG16FeatureExtraction + +class BiLSTM(nn.Cell): + """ + Define a BiLSTM network which contains two LSTM layers + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + """ + def __init__(self, config, is_training=True): + super(BiLSTM, self).__init__() + self.is_training = is_training + self.batch_size = config.batch_size * config.rnn_batch_size + print("batch size is {} ".format(self.batch_size)) + self.input_size = config.input_size + self.hidden_size = config.hidden_size + self.num_step = config.num_step + self.reshape = P.Reshape() + self.cast = P.Cast() + k = (1 / self.hidden_size) ** 0.5 + self.rnn1 = P.DynamicRNN(forget_bias=0.0) + self.rnn_bw = P.DynamicRNN(forget_bias=0.0) + self.w1 = Parameter(np.random.uniform(-k, k, \ + (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1") + self.w1_bw = Parameter(np.random.uniform(-k, k, \ + (self.input_size + self.hidden_size, 4 * self.hidden_size)).astype(np.float32), name="w1_bw") + + self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1") + self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw") + + self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) + self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) + + self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) + self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32)) + self.reverse_seq = P.ReverseV2(axis=[0]) + self.concat = P.Concat() + self.transpose = P.Transpose() + self.concat1 = P.Concat(axis=2) + self.dropout = nn.Dropout(0.7) + self.use_dropout = config.use_dropout + self.reshape = P.Reshape() + self.transpose = P.Transpose() + def construct(self, x): + if self.use_dropout: + x = self.dropout(x) + x = self.cast(x, mstype.float16) + bw_x = self.reverse_seq(x) + y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1) + y1_bw, _, _, _, _, _, _, _ = self.rnn_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw) + y1_bw = self.reverse_seq(y1_bw) + output = self.concat1((y1, y1_bw)) + return output + +class CTPN(nn.Cell): + """ + Define CTPN network + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + """ + def __init__(self, config, is_training=True): + super(CTPN, self).__init__() + self.config = config + self.is_training = is_training + self.num_step = config.num_step + self.input_size = config.input_size + self.batch_size = config.batch_size + self.hidden_size = config.hidden_size + self.vgg16_feature_extractor = VGG16FeatureExtraction() + self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same') + self.rnn = BiLSTM(self.config, is_training=self.is_training).to_float(mstype.float16) + self.reshape = P.Reshape() + self.transpose = P.Transpose() + self.cast = P.Cast() + + # rpn block + self.rpn_with_loss = RPN(config, + self.batch_size, + config.rpn_in_channels, + config.rpn_feat_channels, + config.num_anchors, + config.rpn_cls_out_channels) + self.anchor_generator = AnchorGenerator(config) + self.featmap_size = config.feature_shapes + self.anchor_list = self.get_anchors(self.featmap_size) + self.proposal_generator_test = Proposal(config, + config.test_batch_size, + config.activate_num_classes, + config.use_sigmoid_cls) + self.proposal_generator_test.set_train_local(config, False) + def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids): + # (1,3,600,900) + x = self.vgg16_feature_extractor(img_data) + x = self.conv(x) + x = self.cast(x, mstype.float16) + # (1, 512, 38, 57) + x = self.transpose(x, (0, 2, 1, 3)) + x = self.reshape(x, (-1, self.input_size, self.num_step)) + x = self.transpose(x, (2, 0, 1)) + # (57, 38, 512) + x = self.rnn(x) + # (57, 38, 256) + #x = self.cast(x, mstype.float32) + rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss = self.rpn_with_loss(x, + img_metas, + self.anchor_list, + gt_bboxes, + gt_labels, + gt_valids) + if self.training: + return rpn_loss, cls_score, bbox_pred, rpn_cls_loss, rpn_reg_loss + proposal, proposal_mask = self.proposal_generator_test(cls_score, bbox_pred, self.anchor_list) + return proposal, proposal_mask + + def get_anchors(self, featmap_size): + anchors = self.anchor_generator.grid_anchors(featmap_size) + return Tensor(anchors, mstype.float16) diff --git a/model_zoo/official/cv/ctpn/src/dataset.py b/model_zoo/official/cv/ctpn/src/dataset.py new file mode 100644 index 0000000000..cebe212b80 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/dataset.py @@ -0,0 +1,342 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FasterRcnn dataset""" +from __future__ import division +import os +import numpy as np +from numpy import random +import mmcv +import mindspore.dataset as de +import mindspore.dataset.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as CC +import mindspore.common.dtype as mstype +from mindspore.mindrecord import FileWriter +from src.config import config + +class PhotoMetricDistortion: + """Photo Metric Distortion""" + + def __init__(self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def __call__(self, img, boxes, labels): + img = img.astype('float32') + if random.randint(2): + delta = random.uniform(-self.brightness_delta, self.brightness_delta) + img += delta + mode = random.randint(2) + if mode == 1: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + # convert color from BGR to HSV + img = mmcv.bgr2hsv(img) + # random saturation + if random.randint(2): + img[..., 1] *= random.uniform(self.saturation_lower, + self.saturation_upper) + # random hue + if random.randint(2): + img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + # convert color from HSV to BGR + img = mmcv.hsv2bgr(img) + # random contrast + if mode == 0: + if random.randint(2): + alpha = random.uniform(self.contrast_lower, + self.contrast_upper) + img *= alpha + # randomly swap channels + if random.randint(2): + img = img[..., random.permutation(3)] + return img, boxes, labels + +class Expand: + """expand image""" + + def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): + if to_rgb: + self.mean = mean[::-1] + else: + self.mean = mean + self.min_ratio, self.max_ratio = ratio_range + + def __call__(self, img, boxes, labels): + if random.randint(2): + return img, boxes, labels + h, w, c = img.shape + ratio = random.uniform(self.min_ratio, self.max_ratio) + expand_img = np.full((int(h * ratio), int(w * ratio), c), + self.mean).astype(img.dtype) + left = int(random.uniform(0, w * ratio - w)) + top = int(random.uniform(0, h * ratio - h)) + expand_img[top:top + h, left:left + w] = img + img = expand_img + boxes += np.tile((left, top), 2) + return img, boxes, labels + +def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num): + """rescale operation for image""" + img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) + if img_data.shape[0] > config.img_height: + img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True) + scale_factor = scale_factor * scale_factor2 + img_shape = np.append(img_shape, scale_factor) + img_shape = np.asarray(img_shape, dtype=np.float32) + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes = split_gtbox_label(gt_bboxes) + if gt_bboxes.shape[0] != 0: + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num) + + +def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num): + """resize operation for image""" + img_data = img + img_data, w_scale, h_scale = mmcv.imresize( + img_data, (config.img_width, config.img_height), return_scale=True) + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + img_shape = (config.img_height, config.img_width, 1.0) + img_shape = np.asarray(img_shape, dtype=np.float32) + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes = split_gtbox_label(gt_bboxes) + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num) + + +def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num): + """resize operation for image of eval""" + img_data = img + img_data, w_scale, h_scale = mmcv.imresize( + img_data, (config.img_width, config.img_height), return_scale=True) + scale_factor = np.array( + [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + img_shape = (config.img_height, config.img_width) + img_shape = np.append(img_shape, (h_scale, w_scale)) + img_shape = np.asarray(img_shape, dtype=np.float32) + gt_bboxes = gt_bboxes * scale_factor + shape = gt_bboxes.shape + label_column = np.ones((shape[0], 1), dtype=int) + gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1) + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num) + +def flipped_generation(img, img_shape, gt_bboxes, gt_label, gt_num): + """flipped generation""" + img_data = img + flipped = gt_bboxes.copy() + _, w, _ = img_data.shape + flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 + flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 + return (img_data, img_shape, flipped, gt_label, gt_num) + +def image_bgr_rgb(img, img_shape, gt_bboxes, gt_label, gt_num): + img_data = img[:, :, ::-1] + return (img_data, img_shape, gt_bboxes, gt_label, gt_num) + +def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num): + """photo crop operation for image""" + random_photo = PhotoMetricDistortion() + img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num) + +def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num): + """expand operation for image""" + expand = Expand() + img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label) + + return (img, img_shape, gt_bboxes, gt_label, gt_num) + +def split_gtbox_label(gt_bbox_total): + """split ground truth box label""" + gtbox_list = [] + box_num, _ = gt_bbox_total.shape + for i in range(box_num): + gt_bbox = gt_bbox_total[i] + if gt_bbox[0] % 16 != 0: + gt_bbox[0] = (gt_bbox[0] // 16) * 16 + if gt_bbox[2] % 16 != 0: + gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16 + x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16) + for x0 in x0_array: + gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1]) + return np.array(gtbox_list) + +def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid): + """pad ground truth label""" + pad_max_number = 256 + gt_label = gt_bboxes[:, 4] + gt_valid = gt_bboxes[:, 4] + if gt_bboxes.shape[0] < 256: + gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \ + mode="constant", constant_values=0) + gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1) + gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0) + else: + print("WARNING label num is high than 256") + gt_box = gt_bboxes[0:pad_max_number] + gt_label = gt_label[0:pad_max_number] + gt_valid = gt_valid[0:pad_max_number] + return (img, img_shape, gt_box[:, :4], gt_label, gt_valid) + +def preprocess_fn(image, box, is_training): + """Preprocess function for dataset.""" + def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid): + image_shape = image_shape[:2] + input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid + if config.keep_ratio: + input_data = rescale_column(*input_data) + else: + input_data = resize_column_test(*input_data) + input_data = pad_label(*input_data) + input_data = image_bgr_rgb(*input_data) + output_data = input_data + return output_data + + def _data_aug(image, box, is_training): + """Data augmentation function.""" + image_bgr = image.copy() + image_bgr[:, :, 0] = image[:, :, 2] + image_bgr[:, :, 1] = image[:, :, 1] + image_bgr[:, :, 2] = image[:, :, 0] + image_shape = image_bgr.shape[:2] + gt_box = box[:, :4] + gt_label = box[:, 4] + gt_valid = box[:, 4] + input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid + if not is_training: + return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid) + expand = (np.random.rand() < config.expand_ratio) + if expand: + input_data = expand_column(*input_data) + input_data = photo_crop_column(*input_data) + if config.keep_ratio: + input_data = rescale_column(*input_data) + else: + input_data = resize_column(*input_data) + input_data = pad_label(*input_data) + input_data = image_bgr_rgb(*input_data) + output_data = input_data + return output_data + + return _data_aug(image, box, is_training) + +def anno_parser(annos_str): + """Parse annotation from string to list.""" + annos = [] + for anno_str in annos_str: + anno = list(map(int, anno_str.strip().split(','))) + annos.append(anno) + return annos + +def filter_valid_data(image_dir, anno_path): + """Filter valid image file, which both in image_dir and anno_path.""" + image_files = [] + image_anno_dict = {} + if not os.path.isdir(image_dir): + raise RuntimeError("Path given is not valid.") + if not os.path.isfile(anno_path): + raise RuntimeError("Annotation file is not valid.") + + with open(anno_path, "rb") as f: + lines = f.readlines() + for line in lines: + line_str = line.decode("utf-8").strip() + line_split = str(line_str).split(' ') + file_name = line_split[0] + image_path = os.path.join(image_dir, file_name) + if os.path.isfile(image_path): + image_anno_dict[image_path] = anno_parser(line_split[1:]) + image_files.append(image_path) + return image_files, image_anno_dict + +def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8): + """Create MindRecord file.""" + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + writer = FileWriter(mindrecord_path, file_num) + image_files, image_anno_dict = create_icdar_test_label() + ctpn_json = { + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 6]}, + } + writer.add_schema(ctpn_json, "ctpn_json") + for image_name in image_files: + with open(image_name, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[image_name], dtype=np.int32) + row = {"image": img, "annotation": annos} + writer.write_raw_data([row]) + writer.commit() + +def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0, + is_training=True, num_parallel_workers=4): + """Creatr deeptext dataset with MindDataset.""" + ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\ + num_parallel_workers=8, shuffle=is_training) + decode = C.Decode() + ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=1) + compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training)) + hwc_to_chw = C.HWC2CHW() + normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375)) + type_cast0 = CC.TypeCast(mstype.float32) + type_cast1 = CC.TypeCast(mstype.float16) + type_cast2 = CC.TypeCast(mstype.int32) + type_cast3 = CC.TypeCast(mstype.bool_) + if is_training: + ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], + output_columns=["image", "image_shape", "box", "label", "valid_num"], + column_order=["image", "image_shape", "box", "label", "valid_num"], + num_parallel_workers=num_parallel_workers) + ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], + num_parallel_workers=12) + ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"], + num_parallel_workers=12) + else: + ds = ds.map(operations=compose_map_func, + input_columns=["image", "annotation"], + output_columns=["image", "image_shape", "box", "label", "valid_num"], + column_order=["image", "image_shape", "box", "label", "valid_num"], + num_parallel_workers=num_parallel_workers) + + ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"], + num_parallel_workers=24) + # transpose_column from python to c + ds = ds.map(operations=[type_cast1], input_columns=["image_shape"]) + ds = ds.map(operations=[type_cast1], input_columns=["box"]) + ds = ds.map(operations=[type_cast2], input_columns=["label"]) + ds = ds.map(operations=[type_cast3], input_columns=["valid_num"]) + ds = ds.batch(batch_size, drop_remainder=True) + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/official/cv/ctpn/src/lr_schedule.py b/model_zoo/official/cv/ctpn/src/lr_schedule.py new file mode 100644 index 0000000000..6c6134f388 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/lr_schedule.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr generator for deeptext""" +import math + +def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + learning_rate = float(init_lr) + lr_inc * current_step + return learning_rate + +def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): + base = float(current_step - warmup_steps) / float(decay_steps) + learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr + return learning_rate + +def dynamic_lr(config, base_step): + """dynamic learning rate generator""" + base_lr = config.base_lr + total_steps = int(base_step * config.total_epoch) + warmup_steps = config.warmup_step + lr = [] + for i in range(total_steps): + if i < warmup_steps: + lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio)) + else: + lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) + return lr diff --git a/model_zoo/official/cv/ctpn/src/network_define.py b/model_zoo/official/cv/ctpn/src/network_define.py new file mode 100644 index 0000000000..d31586c7d5 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/network_define.py @@ -0,0 +1,153 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""FasterRcnn training network wrapper.""" + +import time +import numpy as np +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.ops import functional as F +from mindspore.ops import composite as C +from mindspore import ParameterTuple +from mindspore.train.callback import Callback +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer + +time_stamp_init = False +time_stamp_first = 0 +class LossCallBack(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF terminating training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + + def __init__(self, per_print_times=1, rank_id=0): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.count = 0 + self.rpn_loss_sum = 0 + self.rpn_cls_loss_sum = 0 + self.rpn_reg_loss_sum = 0 + self.rank_id = rank_id + + global time_stamp_init, time_stamp_first + if not time_stamp_init: + time_stamp_first = time.time() + time_stamp_init = True + + def step_end(self, run_context): + cb_params = run_context.original_args() + rpn_loss = cb_params.net_outputs[0].asnumpy() + rpn_cls_loss = cb_params.net_outputs[1].asnumpy() + rpn_reg_loss = cb_params.net_outputs[2].asnumpy() + + self.count += 1 + self.rpn_loss_sum += float(rpn_loss) + self.rpn_cls_loss_sum += float(rpn_cls_loss) + self.rpn_reg_loss_sum += float(rpn_reg_loss) + + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + + if self.count >= 1: + global time_stamp_first + time_stamp_current = time.time() + rpn_loss = self.rpn_loss_sum / self.count + rpn_cls_loss = self.rpn_cls_loss_sum / self.count + rpn_reg_loss = self.rpn_reg_loss_sum / self.count + loss_file = open("./loss_{}.log".format(self.rank_id), "a+") + loss_file.write("%lu epoch: %s step: %s ,rpn_loss: %.5f, rpn_cls_loss: %.5f, rpn_reg_loss: %.5f"% + (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, + rpn_loss, rpn_cls_loss, rpn_reg_loss)) + loss_file.write("\n") + loss_file.close() + +class LossNet(nn.Cell): + """FasterRcnn loss method""" + def construct(self, x1, x2, x3): + return x1 + +class WithLossCell(nn.Cell): + """ + Wrap the network with loss function to compute loss. + + Args: + backbone (Cell): The target network to wrap. + loss_fn (Cell): The loss function used to compute loss. + """ + def __init__(self, backbone, loss_fn): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): + rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self._backbone(x, img_shape, gt_bboxe, gt_label, gt_num) + return self._loss_fn(rpn_loss, rpn_cls_loss, rpn_reg_loss) + + @property + def backbone_network(self): + """ + Get the backbone network. + + Returns: + Cell, return backbone network. + """ + return self._backbone + + +class TrainOneStepCell(nn.Cell): + """ + Network training package class. + + Append an optimizer to the training network after that the construct function + can be called to create the backward graph. + + Args: + network (Cell): The training network. + network_backbone (Cell): The forward network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default value is 1.0. + reduce_flag (bool): The reduce flag. Default value is False. + mean (bool): Allreduce method. Default value is False. + degree (int): Device number. Default value is None. + """ + def __init__(self, network, network_backbone, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None): + super(TrainOneStepCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.backbone = network_backbone + self.weights = ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = C.GradOperation(get_by_list=True, + sens_param=True) + self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32)) + self.reduce_flag = reduce_flag + if reduce_flag: + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, x, img_shape, gt_bboxe, gt_label, gt_num): + weights = self.weights + rpn_loss, _, _, rpn_cls_loss, rpn_reg_loss = self.backbone(x, img_shape, gt_bboxe, gt_label, gt_num) + grads = self.grad(self.network, weights)(x, img_shape, gt_bboxe, gt_label, gt_num, self.sens) + if self.reduce_flag: + grads = self.grad_reducer(grads) + return F.depend(rpn_loss, self.optimizer(grads)), rpn_cls_loss, rpn_reg_loss diff --git a/model_zoo/official/cv/ctpn/src/text_connector/__init__.py b/model_zoo/official/cv/ctpn/src/text_connector/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/model_zoo/official/cv/ctpn/src/text_connector/connect_text_lines.py b/model_zoo/official/cv/ctpn/src/text_connector/connect_text_lines.py new file mode 100644 index 0000000000..171beca9a7 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/text_connector/connect_text_lines.py @@ -0,0 +1,65 @@ + +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================import numpy as np +import numpy as np +from src.text_connector.utils import clip_boxes, fit_y +from src.text_connector.get_successions import get_successions + +def connect_text_lines(text_proposals, scores, size): + """ + Connect text lines + + Args: + text_proposals(numpy.array): Predict text proposals. + scores(numpy.array): Bbox predicts scores. + size(numpy.array): Image size. + Returns: + text_recs(numpy.array): Text boxes after connect. + """ + graph = get_successions(text_proposals, scores, size) + text_lines = np.zeros((len(graph), 5), np.float32) + for index, indices in enumerate(graph): + text_line_boxes = text_proposals[list(indices)] + x0 = np.min(text_line_boxes[:, 0]) + x1 = np.max(text_line_boxes[:, 2]) + + offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5 + + lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset) + lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset) + + # the score of a text line is the average score of the scores + # of all text proposals contained in the text line + score = scores[list(indices)].sum() / float(len(indices)) + + text_lines[index, 0] = x0 + text_lines[index, 1] = min(lt_y, rt_y) + text_lines[index, 2] = x1 + text_lines[index, 3] = max(lb_y, rb_y) + text_lines[index, 4] = score + + text_lines = clip_boxes(text_lines, size) + + text_recs = np.zeros((len(text_lines), 9), np.float) + index = 0 + for line in text_lines: + xmin, ymin, xmax, ymax = line[0], line[1], line[2], line[3] + text_recs[index, 0] = xmin + text_recs[index, 1] = ymin + text_recs[index, 2] = xmax + text_recs[index, 3] = ymax + text_recs[index, 4] = line[4] + index = index + 1 + return text_recs diff --git a/model_zoo/official/cv/ctpn/src/text_connector/detector.py b/model_zoo/official/cv/ctpn/src/text_connector/detector.py new file mode 100644 index 0000000000..6518d920fa --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/text_connector/detector.py @@ -0,0 +1,73 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from src.config import config +from src.text_connector.utils import nms +from src.text_connector.connect_text_lines import connect_text_lines + +def filter_proposal(proposals, scores): + """ + Filter text proposals + + Args: + proposals(numpy.array): Text proposals. + Returns: + proposals(numpy.array): Text proposals after filter. + """ + inds = np.where(scores > config.text_proposals_min_scores)[0] + keep_proposals = proposals[inds] + keep_scores = scores[inds] + sorted_inds = np.argsort(keep_scores.ravel())[::-1] + keep_proposals, keep_scores = keep_proposals[sorted_inds], keep_scores[sorted_inds] + nms_inds = nms(np.hstack((keep_proposals, keep_scores)), config.text_proposals_nms_thresh) + keep_proposals, keep_scores = keep_proposals[nms_inds], keep_scores[nms_inds] + return keep_proposals, keep_scores + +def filter_boxes(boxes): + """ + Filter text boxes + + Args: + boxes(numpy.array): Text boxes. + Returns: + boxes(numpy.array): Text boxes after filter. + """ + heights = np.zeros((len(boxes), 1), np.float) + widths = np.zeros((len(boxes), 1), np.float) + scores = np.zeros((len(boxes), 1), np.float) + index = 0 + for box in boxes: + widths[index] = abs(box[2] - box[0]) + heights[index] = abs(box[3] - box[1]) + scores[index] = abs(box[4]) + index += 1 + return np.where((widths / heights > config.min_ratio) & (scores > config.line_min_score) &\ + (widths > (config.text_proposals_width * config.min_num_proposals)))[0] + +def detect(text_proposals, scores, size): + """ + Detect text boxes + + Args: + text_proposals(numpy.array): Predict text proposals. + scores(numpy.array): Bbox predicts scores. + size(numpy.array): Image size. + Returns: + boxes(numpy.array): Text boxes after connect. + """ + keep_proposals, keep_scores = filter_proposal(text_proposals, scores) + connect_boxes = connect_text_lines(keep_proposals, keep_scores, size) + boxes = connect_boxes[filter_boxes(connect_boxes)] + return boxes diff --git a/model_zoo/official/cv/ctpn/src/text_connector/get_successions.py b/model_zoo/official/cv/ctpn/src/text_connector/get_successions.py new file mode 100644 index 0000000000..007a1b6325 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/text_connector/get_successions.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from src.config import config +from src.text_connector.utils import overlaps_v, size_similarity + +def get_successions(text_proposals, scores, im_size): + """ + Get successions text boxes. + + Args: + text_proposals(numpy.array): Predict text proposals. + scores(numpy.array): Bbox predicts scores. + size(numpy.array): Image size. + Returns: + sub_graph(list): Proposals graph. + """ + bboxes_table = [[] for _ in range(int(im_size[1]))] + for index, box in enumerate(text_proposals): + bboxes_table[int(box[0])].append(index) + graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), np.bool) + for index, box in enumerate(text_proposals): + successions_left = [] + for left in range(int(box[0]) + 1, min(int(box[0]) + config.max_horizontal_gap + 1, im_size[1])): + adj_box_indices = bboxes_table[left] + for adj_box_index in adj_box_indices: + if meet_v_iou(text_proposals, adj_box_index, index): + successions_left.append(adj_box_index) + if successions_left: + break + if not successions_left: + continue + succession_index = successions_left[np.argmax(scores[successions_left])] + box_right = text_proposals[succession_index] + succession_right = [] + for right in range(int(box_right[0]) - 1, max(int(box_right[0] - config.max_horizontal_gap), 0) - 1, -1): + adj_box_indices = bboxes_table[right] + for adj_box_index in adj_box_indices: + if meet_v_iou(text_proposals, adj_box_index, index): + succession_right.append(adj_box_index) + if succession_right: + break + if scores[index] >= np.max(scores[succession_right]): + graph[index, succession_index] = True + sub_graph = get_sub_graph(graph) + return sub_graph + +def get_sub_graph(graph): + """ + Get successions text boxes. + + Args: + graph(numpy.array): proposal graph + Returns: + sub_graph(list): Proposals graph after connect. + """ + sub_graphs = [] + for index in range(graph.shape[0]): + if not graph[:, index].any() and graph[index, :].any(): + v = index + sub_graphs.append([v]) + while graph[v, :].any(): + v = np.where(graph[v, :])[0][0] + sub_graphs[-1].append(v) + return sub_graphs + +def meet_v_iou(text_proposals, index1, index2): + """ + Calculate vertical iou. + + Args: + text_proposals(numpy.array): tex proposals + index1(int): text_proposal index + tindex2(int): text proposal index + Returns: + sub_graph(list): Proposals graph after connect. + """ + heights = text_proposals[:, 3] - text_proposals[:, 1] + 1 + return overlaps_v(text_proposals, index1, index2) >= config.min_v_overlaps and \ + size_similarity(heights, index1, index2) >= config.min_size_sim diff --git a/model_zoo/official/cv/ctpn/src/text_connector/utils.py b/model_zoo/official/cv/ctpn/src/text_connector/utils.py new file mode 100644 index 0000000000..fcfa795210 --- /dev/null +++ b/model_zoo/official/cv/ctpn/src/text_connector/utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + + +def threshold(coords, min_, max_): + return np.maximum(np.minimum(coords, max_), min_) + +def clip_boxes(boxes, im_shape): + """ + Clip boxes to image boundaries. + + Args: + boxes(numpy.array):bounding box. + im_shape(numpy.array): image shape. + + Return: + boxes(numpy.array):boundding box after clip. + """ + boxes[:, 0::2] = threshold(boxes[:, 0::2], 0, im_shape[1] - 1) + boxes[:, 1::2] = threshold(boxes[:, 1::2], 0, im_shape[0] - 1) + return boxes + +def overlaps_v(text_proposals, index1, index2): + """ + Calculate vertical overlap ratio. + + Args: + text_proposals(numpy.array): Text proposlas. + index1(int): First text proposal. + index2(int): Second text proposal. + + Return: + overlap(float32): vertical overlap. + """ + h1 = text_proposals[index1][3] - text_proposals[index1][1] + 1 + h2 = text_proposals[index2][3] - text_proposals[index2][1] + 1 + y0 = max(text_proposals[index2][1], text_proposals[index1][1]) + y1 = min(text_proposals[index2][3], text_proposals[index1][3]) + return max(0, y1 - y0 + 1) / min(h1, h2) + +def size_similarity(heights, index1, index2): + """ + Calculate vertical size similarity ratio. + + Args: + heights(numpy.array): Text proposlas heights. + index1(int): First text proposal. + index2(int): Second text proposal. + + Return: + overlap(float32): vertical overlap. + """ + h1 = heights[index1] + h2 = heights[index2] + return min(h1, h2) / max(h1, h2) + +def fit_y(X, Y, x1, x2): + if np.sum(X == X[0]) == len(X): + return Y[0], Y[0] + p = np.poly1d(np.polyfit(X, Y, 1)) + return p(x1), p(x2) + +def nms(bboxs, thresh): + """ + Args: + text_proposals(numpy.array): tex proposals + index1(int): text_proposal index + tindex2(int): text proposal index + """ + x1, y1, x2, y2, scores = np.split(bboxs, 5, axis=1) + x1 = bboxs[:, 0] + y1 = bboxs[:, 1] + x2 = bboxs[:, 2] + y2 = bboxs[:, 3] + scores = bboxs[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + num_dets = bboxs.shape[0] + suppressed = np.zeros(num_dets, dtype=np.int32) + keep = [] + for _i in range(num_dets): + i = order[_i] + if suppressed[i] == 1: + continue + keep.append(i) + x1_i = x1[i] + y1_i = y1[i] + x2_i = x2[i] + y2_i = y2[i] + area_i = areas[i] + for _j in range(_i + 1, num_dets): + j = order[_j] + if suppressed[j] == 1: + continue + x1_j = max(x1_i, x1[j]) + y1_j = max(y1_i, y1[j]) + x2_j = min(x2_i, x2[j]) + y2_j = min(y2_i, y2[j]) + w = max(0.0, x2_j - x1_j + 1) + h = max(0.0, y2_j - y1_j + 1) + inter = w*h + overlap = inter / (area_i+areas[j]-inter) + if overlap >= thresh: + suppressed[j] = 1 + return keep diff --git a/model_zoo/official/cv/ctpn/train.py b/model_zoo/official/cv/ctpn/train.py new file mode 100644 index 0000000000..eb2bc57539 --- /dev/null +++ b/model_zoo/official/cv/ctpn/train.py @@ -0,0 +1,119 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""train CTPN and get checkpoint files.""" +import os +import time +import argparse +import ast +import mindspore.common.dtype as mstype +from mindspore import context, Tensor +from mindspore.communication.management import init +from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor +from mindspore.train import Model +from mindspore.context import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn import Momentum +from mindspore.common import set_seed +from src.ctpn import CTPN +from src.config import config, pretrain_config, finetune_config +from src.dataset import create_ctpn_dataset +from src.lr_schedule import dynamic_lr +from src.network_define import LossCallBack, LossNet, WithLossCell, TrainOneStepCell + +set_seed(1) + +parser = argparse.ArgumentParser(description="CTPN training") +parser.add_argument("--run_distribute", type=ast.literal_eval, default=False, help="Run distribute, default: false.") +parser.add_argument("--pre_trained", type=str, default="", help="Pretrained file path.") +parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.") +parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default: 1.") +parser.add_argument("--rank_id", type=int, default=0, help="Rank id, default: 0.") +parser.add_argument("--task_type", type=str, default="Pretraining",\ + choices=['Pretraining', 'Finetune'], help="task type, default:Pretraining") +args_opt = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id, save_graphs=True) + +if __name__ == '__main__': + if args_opt.run_distribute: + rank = args_opt.rank_id + device_num = args_opt.device_num + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True) + init() + else: + rank = 0 + device_num = 1 + if args_opt.task_type == "Pretraining": + print("Start to do pretraining") + mindrecord_file = config.pretraining_dataset_file + training_cfg = pretrain_config + else: + print("Start to do finetune") + mindrecord_file = config.finetune_dataset_file + training_cfg = finetune_config + + print("CHECKING MINDRECORD FILES ...") + while not os.path.exists(mindrecord_file + ".db"): + time.sleep(5) + + print("CHECKING MINDRECORD FILES DONE!") + + loss_scale = float(config.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as ctpn_pretrain.mindrecord0. + dataset = create_ctpn_dataset(mindrecord_file, repeat_num=1,\ + batch_size=config.batch_size, device_num=device_num, rank_id=rank) + dataset_size = dataset.get_dataset_size() + net = CTPN(config=config, is_training=True) + net = net.set_train() + + load_path = args_opt.pre_trained + if args_opt.task_type == "Pretraining": + print("load backbone vgg16 ckpt {}".format(args_opt.pre_trained)) + param_dict = load_checkpoint(load_path) + for item in list(param_dict.keys()): + if not item.startswith('vgg16_feature_extractor'): + param_dict.pop(item) + load_param_into_net(net, param_dict) + else: + if load_path != "": + print("load pretrain ckpt {}".format(args_opt.pre_trained)) + param_dict = load_checkpoint(load_path) + load_param_into_net(net, param_dict) + loss = LossNet() + lr = Tensor(dynamic_lr(training_cfg, dataset_size), mstype.float32) + opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,\ + weight_decay=config.weight_decay, loss_scale=config.loss_scale) + net_with_loss = WithLossCell(net, loss) + if args_opt.run_distribute: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale, reduce_flag=True, + mean=True, degree=device_num) + else: + net = TrainOneStepCell(net_with_loss, net, opt, sens=config.loss_scale) + + time_cb = TimeMonitor(data_size=dataset_size) + loss_cb = LossCallBack(rank_id=rank) + cb = [time_cb, loss_cb] + if config.save_checkpoint: + ckptconfig = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*dataset_size, + keep_checkpoint_max=config.keep_checkpoint_max) + save_checkpoint_path = os.path.join(config.save_checkpoint_path, "ckpt_" + str(rank) + "/") + ckpoint_cb = ModelCheckpoint(prefix='ctpn', directory=save_checkpoint_path, config=ckptconfig) + cb += [ckpoint_cb] + + model = Model(net) + model.train(training_cfg.total_epoch, dataset, callbacks=cb, dataset_sink_mode=True)