@ -30,6 +30,7 @@
<!-- /TOC -->
# CNN+CTC描述
本文描述了对场景文本识别( STR) 的三个主要贡献。
首先检查训练和评估数据集不一致的内容,以及导致的性能差距。
再引入一个统一的四阶段STR框架, 目前大多数STR模型都能够适应这个框架。
@ -40,6 +41,7 @@
[论文 ](https://arxiv.org/abs/1904.01906 ): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
# 模型架构
示例: 在MindSpore上使用MJSynth和SynthText数据集训练CNN+CTC模型进行文本识别。
# 数据集
@ -47,14 +49,18 @@
[MJSynth ](https://www.robots.ox.ac.uk/~vgg/data/text/ )和[SynthText](https://github.com/ankush-me/SynthText)数据集用于模型训练。[The IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset)数据集用于评估。
- 步骤1:
所有数据集均经过预处理,以.lmdb格式存储, 点击[**此处**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt)可下载。
- 步骤2:
解压下载的文件, 重命名MJSynth数据集为MJ, SynthText数据集为ST, IIIT数据集为IIIT。
- 步骤3:
将上述三个数据集移至`cnctc_data`文件夹中,结构如下:
```
```python
|--- CNNCTC/
|--- cnnctc_data/
|--- ST/
@ -71,8 +77,10 @@
```
- 步骤4:
预处理数据集:
```
```shell
python src/preprocess_dataset.py
```
@ -87,22 +95,24 @@ python src/preprocess_dataset.py
# 环境要求
- 硬件( Ascend)
- 硬件(Ascend)
- 准备Ascend或GPU处理器搭建硬件环境。如需试用昇腾处理器, 请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com, 审核通过即可获得资源。
- 框架
- [MindSpore ](https://www.mindspore.cn/install )
- 如需查看详情,请参见如下资源:
- [MindSpore教程 ](https://www.mindspore.cn/tutorial/zh-CN/master/index.html )
- [MindSpore API](https://www.mindspore.cn/ api/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/ doc/ api_python /zh-CN/master/index.html)
# 快速入门
- 安装依赖:
```
```python
pip install lmdb
pip install Pillow
pip install tqdm
@ -111,19 +121,19 @@ pip install six
- 单机训练:
```
```shell
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```
- 分布式训练:
```
```shell
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```
- 评估:
```
```shell
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
@ -132,7 +142,8 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
## 脚本及样例代码
完整代码结构如下:
```
```python
|--- CNNCTC/
|---README.md // CNN+CTC相关描述
|---train.py // 训练脚本
@ -154,37 +165,40 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
## 脚本参数
在`config.py`中可以同时配置训练参数和评估参数。
参数:
* `--CHARACTER` :字符标签。
* `--NUM_CLASS` : 类别数, 包含所有字符标签和CTCLoss的< blank > 标签。
* `--HIDDEN_SIZE` :模型隐藏大小。
* `--FINAL_FEATURE_WIDTH` :特性的数量。
* `--IMG_H` :输入图像高度。
* `--IMG_W` :输入图像宽度。
* `--TRAIN_DATASET_PATH` :训练数据集的路径。
* `--TRAIN_DATASET_INDEX_PATH` :决定顺序的训练数据集索引文件的路径。
* `--TRAIN_BATCH_SIZE` :训练批次大小。在批次大小和索引文件中,必须确保输入数据是固定的形状。
* `--TRAIN_DATASET_SIZE` :训练数据集大小。
* `--TEST_DATASET_PATH` :测试数据集的路径。
* `--TEST_BATCH_SIZE` :测试批次大小。
* `--TRAIN_EPOCHS` :总训练轮次。
* `--CKPT_PATH` :模型检查点文件路径,可用于恢复训练和评估。
* `--SAVE_PATH` :模型检查点文件保存路径。
* `--LR` :单机训练学习率。
* `--LR_PARA` :分布式训练学习率。
* `--Momentum` :动量。
* `--LOSS_SCALE` :损失放大,避免梯度下溢。
* `--SAVE_CKPT_PER_N_STEP` : 每N步保存模型检查点文件。
* `--KEEP_CKPT_MAX_NUM` :模型检查点文件保存数量上限。
- `--CHARACTER` :字符标签。
- `--NUM_CLASS` : 类别数, 包含所有字符标签和CTCLoss的< blank > 标签。
- `--HIDDEN_SIZE` :模型隐藏大小。
- `--FINAL_FEATURE_WIDTH` :特性的数量。
- `--IMG_H` :输入图像高度。
- `--IMG_W` :输入图像宽度。
- `--TRAIN_DATASET_PATH` :训练数据集的路径。
- `--TRAIN_DATASET_INDEX_PATH` :决定顺序的训练数据集索引文件的路径。
- `--TRAIN_BATCH_SIZE` :训练批次大小。在批次大小和索引文件中,必须确保输入数据是固定的形状。
- `--TRAIN_DATASET_SIZE` :训练数据集大小。
- `--TEST_DATASET_PATH` :测试数据集的路径。
- `--TEST_BATCH_SIZE` :测试批次大小。
- `--TRAIN_EPOCHS` :总训练轮次。
- `--CKPT_PATH` :模型检查点文件路径,可用于恢复训练和评估。
- `--SAVE_PATH` :模型检查点文件保存路径。
- `--LR` :单机训练学习率。
- `--LR_PARA` :分布式训练学习率。
- `--Momentum` :动量。
- `--LOSS_SCALE` :损失放大,避免梯度下溢。
- `--SAVE_CKPT_PER_N_STEP` : 每N步保存模型检查点文件。
- `--KEEP_CKPT_MAX_NUM` :模型检查点文件保存数量上限。
## 训练过程
### 训练
- 单机训练:
```
```shell
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```
@ -193,12 +207,13 @@ bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
`$PRETRAINED_CKPT` 为模型检查点的路径,**可选**。如果值为none, 模型将从头开始训练。
- 分布式训练:
```
```shell
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```
结果和检查点分别写入设备`i`的`./train_parallel_{i}`文件夹。
日志可以在`./train_parallel_{i}/log_{i}.log`中找到,损失值记录在`./train_parallel_{i}/loss.log`中。
日志可以在`./train_parallel_{i}/log_{i}.log`中找到,损失值记录在`./train_parallel_{i}/loss.log`中。
在Ascend上运行分布式任务时需要`$RANK_TABLE_FILE`。
`$PATH_TO_CHECKPOINT` 为模型检查点的路径,**可选**。如果值为none, 模型将从头开始训练。
@ -207,8 +222,7 @@ bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
训练结果保存在示例路径中, 文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果, 如下所示。
```
```python
# 分布式训练结果( 8P)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
@ -235,7 +249,7 @@ epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184
- 评估:
```
```shell
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
@ -260,7 +274,7 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
| 速度 | 1卡: 300毫秒/步; 8卡: 310毫秒/步 |
| 总时间 | 1卡: 18小时; 8卡: 2.3小时 |
| 参数(M) | 177 |
| 脚本 | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/office/cv/cnnctc |
| 脚本 | < https: / / gitee . com / mindspore / mindspore / tree / master / model_zoo / official / cv / cnnctc > |
### 评估性能
@ -277,13 +291,14 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
| 推理模型 | 675M( .ckpt文件) |
## 用法
### 推理
如果您需要在GPU、Ascend 910、Ascend 310等多个硬件平台上使用训练好的模型进行推理, 请参考此[链接](https://www.mindspore.cn/tutory/zh-CN/master/advanced_use/network_migration.html)。以下为简单示例:
- Ascend处理器环境运行
```
```python
# 设置上下文
context.set_context(mode=context.GRAPH_HOME, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)
@ -314,7 +329,7 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
- Ascend处理器环境运行
```
```python
# 加载数据集
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
@ -349,4 +364,5 @@ bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。