!10924 deeplabv3 bugfix: 1. add CPU training usage in README file; 2. add dataset list file generate script.

From: @caojian05
Reviewed-by: @wuxuejian,@oacjiewen
Signed-off-by: @wuxuejian
pull/10924/MERGE
mindspore-ci-bot committed 4 years ago (via Gitee)
commit c5532bcdf0

@ -50,6 +50,8 @@ JPEGImages/00004.jpg SegmentationClassGray/00004.png
......
```
You can also generate the list file automatically by running the script: `python get_dataset_lst.py --data_root=/PATH/TO/DATA`
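For instance, a minimal sketch (it assumes Pascal VOC 2012 and SBD are unpacked under `/PATH/TO/DATA` as `VOCdevkit/` and `benchmark_RELEASE/`, the layout the script expects):
```shell
python get_dataset_lst.py --data_root=/PATH/TO/DATA
# the script converts the annotations to gray PNGs and writes three list files:
#   /PATH/TO/DATA/voc_train_lst.txt      (VOC train pairs)
#   /PATH/TO/DATA/voc_val_lst.txt        (VOC val pairs)
#   /PATH/TO/DATA/vocaug_train_lst.txt   (VOC train plus deduplicated SBD pairs)
```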
- Configure and run build_data.sh to convert the dataset to MindRecords. Arguments in scripts/build_data.sh:
```shell
@ -164,10 +166,12 @@ run_eval_s8_multiscale_flip.sh
├── run_eval_s8_multiscale.sh               # launch ascend evaluation with multiscale in s8 structure
├── run_eval_s8_multiscale_flip.sh          # launch ascend evaluation with multiscale and flip in s8 structure
├── run_standalone_train.sh                 # launch ascend standalone training (1 pc)
├── run_standalone_train_cpu.sh             # launch CPU standalone training
├── src
│   ├── data
│   │   ├── dataset.py                      # mindrecord data generator
│   │   ├── build_seg_data.py               # data preprocessing
│   │   ├── get_dataset_lst.py              # dataset list file generator
│   ├── loss
│   │   ├── loss.py                         # loss definition for deeplabv3
│   ├── nets
@ -189,6 +193,7 @@ Default configuration
```shell
"data_file":"/PATH/TO/MINDRECORD_NAME" # dataset path
"device_target":Ascend # device target
"train_epochs":300 # total epochs
"batch_size":32 # batch size of input tensor
"crop_size":513 # crop size
@ -238,7 +243,7 @@ For 8 devices training, training steps are as follows:
1. Train s16 with the vocaug dataset, fine-tuning from the resnet101 pretrained model. The script is as follows:
```shell
# run_distribute_train_s16_r1.sh
for((i=0;i<=$RANK_SIZE-1;i++));
do
@ -328,8 +333,34 @@ do
done
```
#### Running on CPU
For CPU training, configure the parameters as in the following example. The training script is:
```shell
# run_standalone_train_cpu.sh
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
--device_target=CPU \
--train_dir=${train_path}/ckpt \
--train_epochs=200 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.015 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--save_steps=1500 \
--keep_checkpoint_max=200 >log 2>&1 &
```
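Since the command above runs in the background and redirects its output to `log`, progress can be followed like this (`log` and `${train_path}` are as set in the script above):
```shell
tail -f log             # follow the training output of the backgrounded run
ls ${train_path}/ckpt   # checkpoints, saved every --save_steps (1500) steps
```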
### Result
#### Running on Ascend
- Training vocaug in s16 structure
```shell
@ -386,6 +417,17 @@ epoch time: 5962.164 ms, per step time: 542.015 ms
...
```
#### Running on CPU
- Training voctrain in s16 structure
```bash
epoch: 1 step: 1, loss is 3.655448
epoch: 2 step: 1, loss is 1.5531876
epoch: 3 step: 1, loss is 1.5099041
...
```
## [Evaluation Process](#contents)
### Usage
@ -438,7 +480,7 @@ Note: Here OS is output stride, and MS is multiscale.
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance

@ -63,6 +63,8 @@ Pascal VOC dataset and the Semantic Boundaries Dataset (SBD)
......
```
You can also generate the dataset list file automatically by running the script: `python get_dataset_lst.py --data_root=/PATH/TO/DATA`
- Configure and run build_data.sh to convert the dataset to MindRecords. Arguments in scripts/build_data.sh:
```
@ -177,10 +179,12 @@ run_eval_s8_multiscale_flip.sh
├── run_eval_s8_multiscale.sh               # launch Ascend evaluation with multiscale in s8 structure
├── run_eval_s8_multiscale_flip.sh          # launch Ascend evaluation with multiscale and flip in s8 structure
├── run_standalone_train.sh                 # launch Ascend standalone training (1 pc)
├── run_standalone_train_cpu.sh             # launch CPU standalone training
├── src
│   ├── data
│   │   ├── dataset.py                      # MindRecord data generator
│   │   ├── build_seg_data.py               # data preprocessing
│   │   ├── get_dataset_lst.py              # dataset list file generator
│   ├── loss
│   │   ├── loss.py                         # loss definition for DeepLabV3
│   ├── nets
@ -202,6 +206,7 @@ run_eval_s8_multiscale_flip.sh
```bash
"data_file":"/PATH/TO/MINDRECORD_NAME" # 数据集路径
"device_target":Ascend # 训练后端类型
"train_epochs":300 # 总轮次数
"batch_size":32 # 输入张量的批次大小
"crop_size":513 # 裁剪大小
@ -342,8 +347,34 @@ do
done
```
#### Running on CPU
For CPU training, configure the parameters as in the following example and run the CPU training script:
```shell
# run_standalone_train_cpu.sh
python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \
--device_target=CPU \
--train_dir=${train_path}/ckpt \
--train_epochs=200 \
--batch_size=32 \
--crop_size=513 \
--base_lr=0.015 \
--lr_type=cos \
--min_scale=0.5 \
--max_scale=2.0 \
--ignore_label=255 \
--num_classes=21 \
--model=deeplab_v3_s16 \
--ckpt_pre_trained=/PATH/TO/PRETRAIN_MODEL \
--save_steps=1500 \
--keep_checkpoint_max=200 >log 2>&1 &
```
### Result
#### Running on Ascend
- Training vocaug in s16 structure
```bash
@ -400,6 +431,17 @@ Epoch time: 5962.164, per step time: 542.015
...
```
#### Running on CPU
- Training voctrain in s16 structure
```bash
epoch: 1 step: 1, loss is 3.655448
epoch: 2 step: 1, loss is 1.5531876
epoch: 3 step: 1, loss is 1.5099041
...
```
## Evaluation Process
### Usage
@ -470,7 +512,7 @@ python ${train_code_path}/eval.py --data_root=/PATH/TO/DATA \
| Loss | 0.0065883575 |
| Speed | 31 ms/step (1 pc, s8)<br> 234 ms/step (8 pcs, s8) |
| Fine-tuned checkpoint | 443M (.ckpt file) |
| Script | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/deeplabv3) |
# Description of Random Situation

@ -0,0 +1,142 @@
import argparse
import os

import numpy as np
import scipy.io
from PIL import Image

parser = argparse.ArgumentParser('dataset list generator')
parser.add_argument("--data_root", type=str, default='./',
                    help='root directory where the dataset is stored.')
args, _ = parser.parse_known_args()

data_dir = args.data_root
print("Data dir is:", data_dir)

# Pascal VOC 2012 locations
VOC_IMG_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/JPEGImages')
VOC_ANNO_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/SegmentationClass')
VOC_ANNO_GRAY_DIR = os.path.join(data_dir, 'VOCdevkit/VOC2012/SegmentationClassGray')
VOC_TRAIN_TXT = os.path.join(data_dir, 'VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt')
VOC_VAL_TXT = os.path.join(data_dir, 'VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt')

# Semantic Boundaries Dataset (SBD) locations
SBD_ANNO_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls')
SBD_IMG_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/img')
SBD_ANNO_PNG_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls_png')
SBD_ANNO_GRAY_DIR = os.path.join(data_dir, 'benchmark_RELEASE/dataset/cls_png_gray')
SBD_TRAIN_TXT = os.path.join(data_dir, 'benchmark_RELEASE/dataset/train.txt')
SBD_VAL_TXT = os.path.join(data_dir, 'benchmark_RELEASE/dataset/val.txt')

# generated list files
VOC_TRAIN_LST_TXT = os.path.join(data_dir, 'voc_train_lst.txt')
VOC_VAL_LST_TXT = os.path.join(data_dir, 'voc_val_lst.txt')
VOC_AUG_TRAIN_LST_TXT = os.path.join(data_dir, 'vocaug_train_lst.txt')


def __get_data_list(data_list_file):
    with open(data_list_file, mode='r') as f:
        return f.readlines()


def conv_voc_colorpng_to_graypng():
    """Convert VOC palette-mode annotations to gray PNGs whose pixel values are class ids."""
    if not os.path.exists(VOC_ANNO_GRAY_DIR):
        os.makedirs(VOC_ANNO_GRAY_DIR)

    for ann in os.listdir(VOC_ANNO_DIR):
        ann_im = Image.open(os.path.join(VOC_ANNO_DIR, ann))
        # np.array on a 'P'-mode image returns the raw palette indices, i.e. the class ids
        ann_im = Image.fromarray(np.array(ann_im))
        ann_im.save(os.path.join(VOC_ANNO_GRAY_DIR, ann))


def __gen_palette(cls_nums=256):
    """Build the standard VOC colormap: bits of the class id are spread over the RGB channels."""
    palette = np.zeros((cls_nums, 3), dtype=np.uint8)
    for i in range(cls_nums):
        lbl = i
        j = 0
        while lbl:
            palette[i, 0] |= (((lbl >> 0) & 1) << (7 - j))
            palette[i, 1] |= (((lbl >> 1) & 1) << (7 - j))
            palette[i, 2] |= (((lbl >> 2) & 1) << (7 - j))
            lbl >>= 3
            j += 1
    return palette.flatten()


def conv_sbd_mat_to_png():
    """Convert SBD .mat annotations to gray (class-id) and color (palette) PNGs."""
    if not os.path.exists(SBD_ANNO_PNG_DIR):
        os.makedirs(SBD_ANNO_PNG_DIR)
    if not os.path.exists(SBD_ANNO_GRAY_DIR):
        os.makedirs(SBD_ANNO_GRAY_DIR)

    palette = __gen_palette()
    for an in os.listdir(SBD_ANNO_DIR):
        img_id = an[:-4]
        mat = scipy.io.loadmat(os.path.join(SBD_ANNO_DIR, an))
        anno = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
        anno_png = Image.fromarray(anno)
        # save as gray png
        anno_png.save(os.path.join(SBD_ANNO_GRAY_DIR, img_id + '.png'))
        # save as color png using the palette
        anno_png.putpalette(palette)
        anno_png.save(os.path.join(SBD_ANNO_PNG_DIR, img_id + '.png'))


def create_voc_train_lst_txt():
    voc_train_data_lst = __get_data_list(VOC_TRAIN_TXT)
    with open(VOC_TRAIN_LST_TXT, mode='w') as f:
        for id_ in voc_train_data_lst:
            id_ = id_.strip()
            img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg')
            anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png')
            f.write(img_ + ' ' + anno_ + '\n')


def create_voc_val_lst_txt():
    voc_val_data_lst = __get_data_list(VOC_VAL_TXT)
    with open(VOC_VAL_LST_TXT, mode='w') as f:
        for id_ in voc_val_data_lst:
            id_ = id_.strip()
            img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg')
            anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png')
            f.write(img_ + ' ' + anno_ + '\n')


def create_voc_train_aug_lst_txt():
    voc_train_data_lst = __get_data_list(VOC_TRAIN_TXT)
    voc_val_data_lst = __get_data_list(VOC_VAL_TXT)
    sbd_train_data_lst = __get_data_list(SBD_TRAIN_TXT)
    sbd_val_data_lst = __get_data_list(SBD_VAL_TXT)

    with open(VOC_AUG_TRAIN_LST_TXT, mode='w') as f:
        # SBD ids not already covered by the VOC train/val splits
        for id_ in sbd_train_data_lst + sbd_val_data_lst:
            if id_ in voc_train_data_lst + voc_val_data_lst:
                continue
            id_ = id_.strip()
            # images live in img/, not in the annotation directory
            img_ = os.path.join(SBD_IMG_DIR, id_ + '.jpg')
            anno_ = os.path.join(SBD_ANNO_GRAY_DIR, id_ + '.png')
            f.write(img_ + ' ' + anno_ + '\n')
        # all VOC train ids
        for id_ in voc_train_data_lst:
            id_ = id_.strip()
            img_ = os.path.join(VOC_IMG_DIR, id_ + '.jpg')
            anno_ = os.path.join(VOC_ANNO_GRAY_DIR, id_ + '.png')
            f.write(img_ + ' ' + anno_ + '\n')


if __name__ == '__main__':
    print('converting voc color png to gray png ...')
    conv_voc_colorpng_to_graypng()
    print('converting done.')

    create_voc_train_lst_txt()
    print('generating voc train list success.')

    create_voc_val_lst_txt()
    print('generating voc val list success.')

    print('converting sbd annotations to png ...')
    conv_sbd_mat_to_png()
    print('converting done.')

    create_voc_train_aug_lst_txt()
    print('generating voc train aug list success.')
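After the script finishes, the conversion can be spot-checked. A hedged sketch (`2007_000032` stands in for any id from train.txt): gray annotation PNGs should be single-channel arrays whose values are class ids (0-20) plus 255 for ignored pixels.
```shell
wc -l /PATH/TO/DATA/voc_train_lst.txt /PATH/TO/DATA/vocaug_train_lst.txt
python -c "from PIL import Image; import numpy as np; \
im = np.array(Image.open('/PATH/TO/DATA/VOCdevkit/VOC2012/SegmentationClassGray/2007_000032.png')); \
print(im.shape, np.unique(im))"
```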