!12018 MindSpore Community Network Model Collection Campaign: DenseNet-121

From: @fireinthehole1024
Reviewed-by: 
Signed-off-by:
pull/12018/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a36485fdb4

File diff suppressed because it is too large.

@ -63,8 +63,8 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
# Environment Requirements
- Hardware (Ascend)
    - Prepare a hardware environment with Ascend AI processors. To apply for a trial of the Ascend processor, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com; once approved, you can obtain the resources.
- Hardware (Ascend/GPU)
    - Prepare a hardware environment with Ascend or GPU processors. To apply for a trial of the Ascend processor, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com; once approved, you can obtain the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more details, see the following resources:
@ -75,6 +75,8 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
After installing MindSpore from the official website, you can follow the steps below for training and evaluation:
- Running in an Ascend processor environment
```bash
# Training example
python train.py --data_dir /PATH/TO/DATASET --pretrained /PATH/TO/PRETRAINED_CKPT --is_distributed 0 > train.log 2>&1 &
@ -94,6 +96,22 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)
- Running in a GPU processor environment
```bash
# Training example
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &
# Distributed training example
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]
# Evaluation example
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```
# Script Description
## Script and Sample Code
@ -105,7 +123,9 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
├── README.md                            // DenseNet-121 description
├── scripts
│   ├── run_distribute_train.sh          // Ascend distributed training shell script
│   ├── run_distribute_train_gpu.sh      // GPU distributed training shell script
│   ├── run_distribute_eval.sh           // Ascend evaluation shell script
│   ├── run_distribute_eval_gpu.sh       // GPU evaluation shell script
├── src
│   ├── datasets                         // dataset processing functions
│ ├── losses
@ -176,6 +196,15 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
...
```
- Running in a GPU processor environment
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data_dir=[DATASET_PATH] --is_distributed=0 --device_target='GPU' > train.log 2>&1 &
```
The above Python command runs in the background; logs and model checkpoints are generated under the `output/202x-xx-xx_time_xx_xx/` directory.
### Distributed Training
- Running in an Ascend processor environment
@ -197,6 +226,15 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
...
```
- Running in a GPU processor environment
```bash
cd scripts
sh run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 [DATASET_PATH]
```
The above shell script runs distributed training in the background. Result logs and model checkpoints can be found under `train[X]/output/202x-xx-xx_time_xx_xx_xx/`.
## Evaluation Process
### Evaluation
@ -218,35 +256,52 @@ DenseNet-121 is built on 4 densely connected blocks. Within each dense block, every layer
2020-08-24 09:21:50,551:INFO:after allreduce eval: top5_correct=46224, tot=49920, acc=92.60%
```
- GPU processor environment
Run the following command for evaluation.
```bash
python eval.py --data_dir=[DATASET_PATH] --device_target='GPU' --pretrained=[CHECKPOINT_PATH] > eval.log 2>&1 &
OR
sh run_distribute_eval_gpu.sh 1 0 [DATASET_PATH] [CHECKPOINT_PATH]
```
The above Python command runs in the background; results can be viewed in the "eval/eval.log" file. The accuracy on the test dataset is as follows:
```log
2021-02-04 14:20:50,551:INFO:after allreduce eval: top1_correct=37637, tot=49984, acc=75.30%
2021-02-04 14:20:50,551:INFO:after allreduce eval: top5_correct=46370, tot=49984, acc=92.77%
```
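For reference, the accuracies reported above are simply the correct counts divided by the total number of evaluated images. A quick, illustrative check (an editor's example, not part of the repository's scripts), using the values from the GPU eval.log excerpt:
```python
# Reproduce the GPU accuracy figures from the eval.log excerpt above.
top1_correct, top5_correct, img_tot = 37637, 46370, 49984
print(f"top1 acc = {top1_correct / img_tot:.2%}")  # ~75.30%
print(f"top5 acc = {top5_correct / img_tot:.2%}")  # ~92.77%
```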
# Model Description
## Performance
### Training Accuracy Results
| Parameters           | DenseNet                     |
| -------------------- | ---------------------------- |
| Model version        | DenseNet-121                 |
| Resource             | Ascend 910                   |
| Upload date          | 2020/9/15                    |
| MindSpore version    | 1.0.0                        |
| Dataset              | ImageNet                     |
| Epochs               | 120                          |
| Output               | Probability                  |
| Training performance | Top-1: 75.13%; Top-5: 92.57% |
| Parameters           | Ascend                       | GPU                          |
| -------------------- | ---------------------------- | ---------------------------- |
| Model version        | DenseNet-121                 | DenseNet-121                 |
| Resource             | Ascend 910                   | Tesla V100-PCIE              |
| Upload date          | 2020/9/15                    | 2021/2/4                     |
| MindSpore version    | 1.0.0                        | 1.1.1                        |
| Dataset              | ImageNet                     | ImageNet                     |
| Epochs               | 120                          | 120                          |
| Output               | Probability                  | Probability                  |
| Training performance | Top-1: 75.13%; Top-5: 92.57% | Top-1: 75.30%; Top-5: 92.77% |
### Training Performance Results
| Parameters          | DenseNet                                |
| ------------------- | --------------------------------------- |
| Model version       | DenseNet-121                            |
| Resource            | Ascend 910                              |
| Upload date         | 2020/9/15                               |
| MindSpore version   | 1.0.0                                   |
| Dataset             | ImageNet                                |
| batch_size          | 32                                      |
| Output              | Probability                             |
| Speed               | 1 card: 760 img/s; 8 cards: 6000 img/s  |
| Parameters          | Ascend                                  | GPU                                     |
| ------------------- | --------------------------------------- | --------------------------------------- |
| Model version       | DenseNet-121                            | DenseNet-121                            |
| Resource            | Ascend 910                              | Tesla V100-PCIE                         |
| Upload date         | 2020/9/15                               | 2021/2/4                                |
| MindSpore version   | 1.0.0                                   | 1.1.1                                   |
| Dataset             | ImageNet                                | ImageNet                                |
| batch_size          | 32                                      | 32                                      |
| Output              | Probability                             | Probability                             |
| Speed               | 1 card: 760 img/s; 8 cards: 6000 img/s  | 1 card: 161 img/s; 8 cards: 1288 img/s  |
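As a rough illustration of what the speed figures imply, here is an editor's back-of-the-envelope estimate (not part of the original README; the standard ImageNet-1k training-set size of about 1.28 million images is assumed):
```python
# Hypothetical epoch-time estimate derived from the 8-card throughput figures in the table above.
IMAGENET_TRAIN_IMAGES = 1_281_167  # assumed ImageNet-1k training-set size

for device, imgs_per_sec in (("Ascend 910, 8 cards", 6000), ("Tesla V100-PCIE, 8 cards", 1288)):
    minutes_per_epoch = IMAGENET_TRAIN_IMAGES / imgs_per_sec / 60
    print(f"{device}: ~{minutes_per_epoch:.1f} minutes per epoch")
```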
# Description of Random Situation

@ -38,10 +38,6 @@ from src.datasets import classification_dataset
from src.network import DenseNet121
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True, device_id=devid)
class ParameterReduce(nn.Cell):
"""
@ -83,6 +79,9 @@ def parse_args(cloud_args=None):
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
# platform
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
@ -114,6 +113,42 @@ def merge_args(args, cloud_args):
args_dict[key] = val
return args
def generate_results(model, rank, group_size, top1_correct, top5_correct, img_tot):
    """Aggregate per-rank top-1/top-5/total counters by exchanging .npy files in a shared cache directory."""
    model_md5 = model.replace('/', '')
    tmp_dir = '../cache'
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)
    # save this rank's counters
    top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, rank, model_md5)
    np.save(top1_correct_npy, top1_correct)
    np.save(top5_correct_npy, top5_correct)
    np.save(img_tot_npy, img_tot)
    # file-based barrier: wait until every rank has written its three files
    while True:
        rank_ok = True
        for other_rank in range(group_size):
            top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
            if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) \
                    or not os.path.exists(img_tot_npy):
                rank_ok = False
        if rank_ok:
            break
    # sum the counters collected from all ranks
    top1_correct_all = 0
    top5_correct_all = 0
    img_tot_all = 0
    for other_rank in range(group_size):
        top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
        top1_correct_all += np.load(top1_correct_npy)
        top5_correct_all += np.load(top5_correct_npy)
        img_tot_all += np.load(img_tot_npy)
    return [[top1_correct_all], [top5_correct_all], [img_tot_all]]
def test(cloud_args=None):
"""
network eval function. Get top1 and top5 ACC from classification.
@ -121,6 +156,12 @@ def test(cloud_args=None):
"""
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
save_graphs=True)
if args.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
# init distributed
if args.is_distributed:
init()
@ -164,7 +205,8 @@ def test(cloud_args=None):
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(model))
network.add_flags_recursive(fp16=True)
if args.device_target == 'Ascend':
network.add_flags_recursive(fp16=True)
img_tot = 0
top1_correct = 0
@ -186,41 +228,9 @@ def test(cloud_args=None):
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
if args.is_distributed:
model_md5 = model.replace('/', '')
tmp_dir = '../cache'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, args.rank, model_md5)
np.save(top1_correct_npy, top1_correct)
np.save(top5_correct_npy, top5_correct)
np.save(img_tot_npy, img_tot)
while True:
rank_ok = True
for other_rank in range(args.group_size):
top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) \
or not os.path.exists(img_tot_npy):
rank_ok = False
if rank_ok:
break
top1_correct_all = 0
top5_correct_all = 0
img_tot_all = 0
for other_rank in range(args.group_size):
top1_correct_npy = '{}/top1_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
top5_correct_npy = '{}/top5_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
img_tot_npy = '{}/img_tot_rank_{}_{}.npy'.format(tmp_dir, other_rank, model_md5)
top1_correct_all += np.load(top1_correct_npy)
top5_correct_all += np.load(top5_correct_npy)
img_tot_all += np.load(img_tot_npy)
results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
results = generate_results(model, args.rank, args.group_size, top1_correct,
top5_correct, img_tot)
results = np.array(results)
else:
results = np.array(results)

@ -0,0 +1,63 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 4 ]
then
echo "Usage: sh run_distribute_eval_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export DEVICE_NUM=$1
export RANK_SIZE=$1
# check checkpoint file
if [ ! -f $4 ]
then
echo "error: CHECKPOINT_PATH=$4 is not a file"
exit 1
fi
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../eval" ];
then
rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit
export CUDA_VISIBLE_DEVICES="$2"
if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../eval.py \
--data_dir=$3 \
--device_target='GPU' \
--pretrained=$4 > eval.log 2>&1 &
else
python3 ${BASEPATH}/../eval.py \
--data_dir=$3 \
--device_target='GPU' \
--pretrained=$4 > eval.log 2>&1 &
fi

@ -0,0 +1,70 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 3 ]
then
echo "Usage: sh run_distribute_train_gpu.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [PRE_TRAINED](optional)"
exit 1
fi
if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export DEVICE_NUM=$1
export RANK_SIZE=$1
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "../train" ];
then
rm -rf ../train
fi
mkdir ../train
cd ../train || exit
export CUDA_VISIBLE_DEVICES="$2"
if [ -f $4 ] # pretrained ckpt
then
if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--device_target='GPU' \
--pretrained=$4 > train.log 2>&1 &
else
python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--is_distributed=0 \
--device_target='GPU' \
--pretrained=$4 > train.log 2>&1 &
fi
else
if [ $1 -gt 1 ]
then
mpirun -n $1 --allow-run-as-root python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--device_target='GPU' > train.log 2>&1 &
else
python3 ${BASEPATH}/../train.py \
--data_dir=$3 \
--is_distributed=0 \
--device_target='GPU' > train.log 2>&1 &
fi
fi

@ -39,10 +39,6 @@ from src.lr_scheduler import MultiStepLR, CosineAnnealingLR
from src.utils.logging import get_logger
from src.config import config
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target="Davinci", save_graphs=False, device_id=devid)
set_seed(1)
class BuildTrainNetwork(nn.Cell):
@ -124,6 +120,9 @@ def parse_args(cloud_args=None):
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
# platform
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
@ -172,6 +171,13 @@ def train(cloud_args=None):
"""training process"""
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False)
if args.device_target == 'Ascend':
devid = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=devid)
# init distributed
if args.is_distributed:
init()
@ -181,7 +187,7 @@ def train(cloud_args=None):
if args.is_dynamic_loss_scale == 1:
args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt
# select for master rank save ckpt or all rank save, compatiable for model parallel
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
@ -269,7 +275,13 @@ def train(cloud_args=None):
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
gradients_mean=True)
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")
if args.device_target == 'Ascend':
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O3")
elif args.device_target == 'GPU':
model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager, amp_level="O0")
else:
raise ValueError("Unsupported device target.")
# checkpoint save
progress_cb = ProgressMonitor(args)
