From 0a774b4aed320e2704db774f3d7938e9914be046 Mon Sep 17 00:00:00 2001 From: panpanrui <8315094+panpanrui@user.noreply.gitee.com> Date: Mon, 22 Feb 2021 11:03:00 +0800 Subject: [PATCH] =?UTF-8?q?resnet152=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model_zoo/official/cv/resnet152/README-CN.md | 203 +++++++++ model_zoo/official/cv/resnet152/eval.py | 65 +++ .../resnet152/scripts/run_distribute_train.sh | 89 ++++ .../official/cv/resnet152/scripts/run_eval.sh | 64 +++ .../resnet152/scripts/run_standalone_train.sh | 77 ++++ .../cv/resnet152/src/CrossEntropySmooth.py | 38 ++ model_zoo/official/cv/resnet152/src/config.py | 124 ++++++ .../official/cv/resnet152/src/dataset.py | 300 +++++++++++++ .../official/cv/resnet152/src/lr_generator.py | 199 +++++++++ model_zoo/official/cv/resnet152/src/resnet.py | 407 ++++++++++++++++++ model_zoo/official/cv/resnet152/train.py | 150 +++++++ 11 files changed, 1716 insertions(+) create mode 100644 model_zoo/official/cv/resnet152/README-CN.md create mode 100644 model_zoo/official/cv/resnet152/eval.py create mode 100644 model_zoo/official/cv/resnet152/scripts/run_distribute_train.sh create mode 100644 model_zoo/official/cv/resnet152/scripts/run_eval.sh create mode 100644 model_zoo/official/cv/resnet152/scripts/run_standalone_train.sh create mode 100644 model_zoo/official/cv/resnet152/src/CrossEntropySmooth.py create mode 100644 model_zoo/official/cv/resnet152/src/config.py create mode 100644 model_zoo/official/cv/resnet152/src/dataset.py create mode 100644 model_zoo/official/cv/resnet152/src/lr_generator.py create mode 100644 model_zoo/official/cv/resnet152/src/resnet.py create mode 100644 model_zoo/official/cv/resnet152/train.py diff --git a/model_zoo/official/cv/resnet152/README-CN.md b/model_zoo/official/cv/resnet152/README-CN.md new file mode 100644 index 0000000000..75e7f9e2f9 --- /dev/null +++ b/model_zoo/official/cv/resnet152/README-CN.md @@ -0,0 +1,203 @@ + +# Resnet152描述 + +## 概述 + +ResNet系列模型是在2015年提出的,通过ResNet单元,成功训练152层神经网络,一举在ILSVRC2015比赛中取得冠军。该网络创新性的提出了残差结构,通过堆叠多个残差结构从而构建了ResNet网络。传统的卷积网络或全连接网络或多或少存在信息丢失的问题,还会造成梯度消失或爆炸,导致深度网络训练失败,ResNet则在一定程度上解决了这个问题。通过将输入信息传递给输出,确保信息完整性。整个网络只需要学习输入和输出的差异部分,简化了学习目标和难度。正因如此,ResNet十分受欢迎,甚至可以直接用于ConceptNet网络。 + +如下为MindSpore使用ImageNet2012数据集对ResNet152进行训练的示例。 + +## 论文 + +1. [论文](https://arxiv.org/pdf/1512.03385.pdf): Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun."Deep Residual Learning for Image Recognition" + +# 模型架构 + +ResNet152的总体网络架构如下:[链接](https://arxiv.org/pdf/1512.03385.pdf) + +# 数据集 + +使用的数据集:[ImageNet2012](http://www.image-net.org/) + +- 数据集大小:共1000个类、224*224彩色图像 + - 训练集:共1,281,167张图像 + - 测试集:共50,000张图像 +- 数据格式:JPEG + - 注:数据在dataset.py中处理。 +- 下载数据集,目录结构如下: + +```text +└─dataset + ├─ilsvrc # 训练数据集 + └─validation_preprocess # 评估数据集 +``` + +# 环境要求 + +- 硬件 + - 准备Ascend处理器搭建硬件环境。如需试用昇腾处理器,请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com,审核通过即可获得资源。 +- 框架 + - [MindSpore](https://www.mindspore.cn/install/en) +- 如需查看详情,请参见如下资源: + - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html) + +# 快速入门 + +通过官方网站安装MindSpore后,您可以按照如下步骤进行训练和评估: + +- Ascend处理器环境运行 + +```Shell +# 分布式训练 +用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) + +# 单机训练 +用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) + +# 运行评估示例 +用法:sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +# 脚本说明 + +## 脚本及样例代码 + +```text +└──resnet + ├── README.md + ├── scripts + ├── run_distribute_train.sh # 启动Ascend分布式训练(8卡) + ├── run_eval.sh # 启动Ascend评估 + └── run_standalone_train.sh # 启动Ascend单机训练(单卡) + ├── src + ├── config.py # 参数配置 + ├── dataset.py # 数据预处理 + ├── CrossEntropySmooth.py # ImageNet2012数据集的损失定义 + ├── lr_generator.py # 生成每个步骤的学习率 + └── resnet.py # ResNet骨干网络,包括ResNet50、ResNet101、SE-ResNet50和Resnet152 + ├── eval.py # 评估网络 + └── train.py # 训练网络 +``` + +# 脚本参数 + +在config.py中可以同时配置训练参数和评估参数。 + +- 配置ResNet152和ImageNet2012数据集。 + +```Python +"class_num":1001, # 数据集类数 +"batch_size":32, # 输入张量的批次大小 +"loss_scale":1024, # 损失等级 +"momentum":0.9, # 动量优化器 +"weight_decay":1e-4, # 权重衰减 +"epoch_size":140, # 训练周期大小 +"save_checkpoint":True, # 是否保存检查点 +"save_checkpoint_epochs":5, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存 +"keep_checkpoint_max":10, # 只保存最后一个keep_checkpoint_max检查点 +"save_checkpoint_path":"./", # 检查点相对于执行路径的保存路径 +"warmup_epochs":0, # 热身周期数 +"lr_decay_mode":"steps", # 用于生成学习率的衰减模式 +"use_label_smooth":True, # 标签平滑 +"label_smooth_factor":0.1, # 标签平滑因子 +"lr":0.1 # 基础学习率 +"lr_end":0.0001, # 最终学习率 +``` + +# 训练过程 + +## 用法 + +## Ascend处理器环境运行 + +```Shell +# 分布式训练 +用法:sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) + +# 单机训练 +用法:sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](可选) + +``` + +分布式训练需要提前创建JSON格式的HCCL配置文件。 + +具体操作,参见[hccn_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools)中的说明。 + +训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。 + +## 结果 + +- 使用ImageNet2012数据集训练ResNet50 + +```text +# 分布式训练结果(8P) +epoch: 1 step: 5004, loss is 4.184874 +epoch: 2 step: 5004, loss is 4.013571 +epoch: 3 step: 5004, loss is 3.695777 +epoch: 4 step: 5004, loss is 3.3244863 +epoch: 5 step: 5004, loss is 3.4899402 +... +``` + +# 评估过程 + +## 用法 + +### Ascend处理器环境运行 + +```Shell +# 评估 +Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +```Shell +# 评估示例 +sh run_eval.sh /data/dataset/ImageNet/imagenet_original Resnet152-140_5004.ckpt +``` + +训练过程中可以生成检查点。 + +## 结果 + +评估结果保存在示例路径中,文件夹名为“eval”。您可在此路径下的日志找到如下结果: + +- 使用ImageNet2012数据集评估ResNet152 + +```text +result: {'top_5_accuracy': 0.9438420294494239, 'top_1_accuracy': 0.78817221518} ckpt= resnet152-140_5004.ckpt +``` + +# 模型描述 + +## 性能 + +### 评估性能 + +#### ImageNet2012上的ResNet152 + +| 参数 | Ascend 910 | +|---|---| +| 模型版本 | ResNet152 | +| 资源 | Ascend 910;CPU:2.60GHz,192核;内存:755G | +| 上传日期 |2021-02-10 ; | +| MindSpore版本 | 1.0.1 | +| 数据集 | ImageNet2012 | +| 训练参数 | epoch=140, steps per epoch=5004, batch_size = 32 | +| 优化器 | Momentum | +| 损失函数 |Softmax交叉熵 | +| 输出 | 概率 | +| 损失 | 1.7375104 | +|速度|47.47毫秒/步(8卡) | +|总时长 | 577分钟 | +|参数(M) | 60.19 | +| 微调检查点 | 462M(.ckpt文件) | +| 脚本 | [链接](https://gitee.com/panpanrui/mindspore/tree/master/model_zoo/official/cv/resnet152) | + +# 随机情况说明 + +dataset.py中设置了“create_dataset”函数内的种子,同时还使用了train.py中的随机种子。 + +# ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 \ No newline at end of file diff --git a/model_zoo/official/cv/resnet152/eval.py b/model_zoo/official/cv/resnet152/eval.py new file mode 100644 index 0000000000..f239c03cc0 --- /dev/null +++ b/model_zoo/official/cv/resnet152/eval.py @@ -0,0 +1,65 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""eval resnet.""" +import argparse +from mindspore import context +from mindspore.common import set_seed +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from src.CrossEntropySmooth import CrossEntropySmooth +from src.resnet import resnet152 as resnet +from src.config import config5 as config +from src.dataset import create_dataset2 as create_dataset + +parser = argparse.ArgumentParser(description='Image classification') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--data_url', type=str, default=None, help='Dataset path') +args_opt = parser.parse_args() + +set_seed(1) + +if __name__ == '__main__': + target = "Ascend" + + # init context + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + + # create dataset + local_data_path = args_opt.data_url + print('Download data.') + dataset = create_dataset(dataset_path=local_data_path, do_train=False, batch_size=config.batch_size, + target=target) + step_size = dataset.get_dataset_size() + + # define net + net = resnet(class_num=config.class_num) + + ckpt_name = args_opt.checkpoint_path + param_dict = load_checkpoint(ckpt_name) + load_param_into_net(net, param_dict) + net.set_train(False) + + # define loss, model + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction='mean', + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + + # define model + model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) + + # eval model + res = model.eval(dataset) + print("result:", res, "ckpt=", ckpt_name) diff --git a/model_zoo/official/cv/resnet152/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet152/scripts/run_distribute_train.sh new file mode 100644 index 0000000000..dc0afaf636 --- /dev/null +++ b/model_zoo/official/cv/resnet152/scripts/run_distribute_train.sh @@ -0,0 +1,89 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_distribute_train.sh RANK_TABLE_FILE DATA_PATH PRETRAINED_CKPT_PATH](optional)" +echo "For example: bash run_distribute_train.sh hccl_8p_01234567_127.0.0.1.json /path/dataset" +echo "It is better to use the absolute path." +echo "==============================================================================================================" + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ $# == 3 ] +then + PATH3=$(get_real_path $5) +fi + +if [ ! -f $PATH1 ] +then + echo "error: RANK_TABLE_FILE=$PATH1 is not a file" +exit 1 +fi + +if [ ! -d $PATH2 ] +then + echo "error: DATA_PATH=$PATH2 is not a directory" +exit 1 +fi + +if [ $# == 3 ] && [ ! -f $PATH3 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE=$PATH1 + +DATA_PATH=$2 +export DATA_PATH=${DATA_PATH} + +for((i=0;i<${RANK_SIZE};i++)) +do + rm -rf device$i + mkdir device$i + cp ../*.py ./device$i + cp *.sh ./device$i + cp -r ../src ./device$i + cd ./device$i + export DEVICE_ID=$i + export RANK_ID=$i + echo "start training for device $i" + env > env$i.log + + if [ $# == 2 ] + then + python train.py --run_distribute=True --data_url=$PATH2 &> train.log & + fi + + if [ $# == 3 ] + then + python train.py --run_distribute=True --data_url=$PATH2 --pre_trained=$PATH3 &> train.log & + fi + + cd ../ +done diff --git a/model_zoo/official/cv/resnet152/scripts/run_eval.sh b/model_zoo/official/cv/resnet152/scripts/run_eval.sh new file mode 100644 index 0000000000..1cc5c35828 --- /dev/null +++ b/model_zoo/official/cv/resnet152/scripts/run_eval.sh @@ -0,0 +1,64 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_eval.sh DATA_PATH CHECKPOINT_PATH " +echo "For example: bash run.sh /path/dataset Resnet152-140_5004.ckpt" +echo "It is better to use the absolute path." +echo "==============================================================================================================" + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $PATH2 ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval +env > env.log +echo "start evaluation for device $DEVICE_ID" +python eval.py --data_url=$PATH1 --checkpoint_path=$PATH2 &> eval.log & +cd .. \ No newline at end of file diff --git a/model_zoo/official/cv/resnet152/scripts/run_standalone_train.sh b/model_zoo/official/cv/resnet152/scripts/run_standalone_train.sh new file mode 100644 index 0000000000..7f09d5bd24 --- /dev/null +++ b/model_zoo/official/cv/resnet152/scripts/run_standalone_train.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_standalone_train.sh DATA_PATH PRETRAINED_CKPT_PATH(optional)" +echo "For example: bash run_standalone_train.sh /path/dataset" +echo "It is better to use the absolute path." +echo "==============================================================================================================" + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +if [ $# == 2 ] +then + PATH2=$(get_real_path $2) +fi + +if [ ! -d $PATH1 ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ $# == 2 ] && [ ! -f $PATH2 ] +then + echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=6 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train +echo "start training for device $DEVICE_ID" +env > env.log +if [ $# == 1 ] +then + python train.py --data_url=$PATH1 &> train.log & +fi + +if [ $# == 2 ] +then + python train.py --data_url=$PATH1 --pre_trained=$PATH2 &> train.log & +fi +cd .. + + diff --git a/model_zoo/official/cv/resnet152/src/CrossEntropySmooth.py b/model_zoo/official/cv/resnet152/src/CrossEntropySmooth.py new file mode 100644 index 0000000000..6d63b66694 --- /dev/null +++ b/model_zoo/official/cv/resnet152/src/CrossEntropySmooth.py @@ -0,0 +1,38 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +class CrossEntropySmooth(_Loss): + """CrossEntropy""" + def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000): + super(CrossEntropySmooth, self).__init__() + self.onehot = P.OneHot() + self.sparse = sparse + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) + + def construct(self, logit, label): + if self.sparse: + label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, label) + return loss diff --git a/model_zoo/official/cv/resnet152/src/config.py b/model_zoo/official/cv/resnet152/src/config.py new file mode 100644 index 0000000000..11c3d99baa --- /dev/null +++ b/model_zoo/official/cv/resnet152/src/config.py @@ -0,0 +1,124 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in train.py and eval.py +""" +from easydict import EasyDict as ed + +# config for resent50, cifar10 +config1 = ed({ + "class_num": 10, + "batch_size": 32, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 90, + "pretrain_epoch_size": 0, + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./", + "warmup_epochs": 5, + "lr_decay_mode": "poly", + "lr_init": 0.01, + "lr_end": 0.00001, + "lr_max": 0.1 +}) + +# config for resnet50, imagenet2012 +config2 = ed({ + "class_num": 1001, + "batch_size": 256, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 90, + "pretrain_epoch_size": 0, + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./", + "warmup_epochs": 0, + "lr_decay_mode": "linear", + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr_init": 0, + "lr_max": 0.8, + "lr_end": 0.0 +}) + +# config for resent101, imagenet2012 +config3 = ed({ + "class_num": 1001, + "batch_size": 32, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 120, + "pretrain_epoch_size": 0, + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./", + "warmup_epochs": 0, + "lr_decay_mode": "cosine", + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr": 0.1 +}) + +# config for se-resnet50, imagenet2012 +config4 = ed({ + "class_num": 1001, + "batch_size": 32, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 28, + "train_epoch_size": 24, + "pretrain_epoch_size": 0, + "save_checkpoint": True, + "save_checkpoint_epochs": 4, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./", + "warmup_epochs": 3, + "lr_decay_mode": "cosine", + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr_init": 0.0, + "lr_max": 0.3, + "lr_end": 0.0001 +}) + +# config for resnet152, imagenet2012 +config5 = ed({ + "class_num": 1001, + "batch_size": 32, + "loss_scale": 1024, + "momentum": 0.9, + "weight_decay": 1e-4, + "epoch_size": 140, + "save_checkpoint": True, + "save_checkpoint_epochs": 5, + "keep_checkpoint_max": 10, + "save_checkpoint_path": "./", + "warmup_epochs": 0, + "lr_decay_mode": "steps", + "use_label_smooth": True, + "label_smooth_factor": 0.1, + "lr_init": 0.0, + "lr_max": 0.1, + "lr_end": 0.0001 +}) diff --git a/model_zoo/official/cv/resnet152/src/dataset.py b/model_zoo/official/cv/resnet152/src/dataset.py new file mode 100644 index 0000000000..020ecb5d23 --- /dev/null +++ b/model_zoo/official/cv/resnet152/src/dataset.py @@ -0,0 +1,300 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +create train or eval dataset. +""" +import os +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.vision.c_transforms as C +import mindspore.dataset.transforms.c_transforms as C2 +from mindspore.communication.management import init, get_rank, get_group_size + + +def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False): + """ + create a train or evaluate cifar10 dataset for resnet50 + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + distribute(bool): data for distribute or not. Default: False + + Returns: + dataset + """ + if not do_train: + dataset_path = os.path.join(dataset_path, 'eval') + else: + dataset_path = os.path.join(dataset_path, 'train') + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + if distribute: + init() + rank_id = get_rank() + device_num = get_group_size() + else: + device_num = 1 + if device_num == 1: + ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + + # define map operations + trans = [] + if do_train: + trans += [ + C.RandomCrop((32, 32), (4, 4, 4, 4)), + C.RandomHorizontalFlip(prob=0.5) + ] + + trans += [ + C.Resize((224, 224)), + C.Rescale(1.0 / 255.0, 0.0), + C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False): + """ + create a train or eval imagenet2012 dataset for resnet50 + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + distribute(bool): data for distribute or not. Default: False + + Returns: + dataset + """ + if not do_train: + dataset_path = os.path.join(dataset_path, 'val') + else: + dataset_path = os.path.join(dataset_path, 'train') + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + if distribute: + init() + rank_id = get_rank() + device_num = get_group_size() + else: + device_num = 1 + + if device_num == 1: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + + image_size = 224 + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) + ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False): + """ + create a train or eval imagenet2012 dataset for resnet101 + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + distribute(bool): data for distribute or not. Default: False + + Returns: + dataset + """ + if not do_train: + dataset_path = os.path.join(dataset_path, 'val') + else: + dataset_path = os.path.join(dataset_path, 'train') + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + if distribute: + init() + rank_id = get_rank() + device_num = get_group_size() + else: + device_num = 1 + rank_id = 1 + if device_num == 1: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) + else: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, + num_shards=device_num, shard_id=rank_id) + image_size = 224 + mean = [0.475 * 255, 0.451 * 255, 0.392 * 255] + std = [0.275 * 255, 0.267 * 255, 0.278 * 255] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(rank_id / (rank_id + 1)), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(image_size), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + + ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8) + ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend", distribute=False): + """ + create a train or eval imagenet2012 dataset for se-resnet50 + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + distribute(bool): data for distribute or not. Default: False + + Returns: + dataset + """ + if target == "Ascend": + device_num, rank_id = _get_rank_info() + else: + if distribute: + init() + rank_id = get_rank() + device_num = get_group_size() + else: + device_num = 1 + if device_num == 1: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True) + else: + ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=12, shuffle=True, + num_shards=device_num, shard_id=rank_id) + image_size = 224 + mean = [123.68, 116.78, 103.94] + std = [1.0, 1.0, 1.0] + + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + C.RandomHorizontalFlip(prob=0.5), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + else: + trans = [ + C.Decode(), + C.Resize(292), + C.CenterCrop(256), + C.Normalize(mean=mean, std=std), + C.HWC2CHW() + ] + + type_cast_op = C2.TypeCast(mstype.int32) + ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=12) + ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=12) + + # apply batch operations + ds = ds.batch(batch_size, drop_remainder=True) + + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = 1 + rank_id = 0 + + return rank_size, rank_id diff --git a/model_zoo/official/cv/resnet152/src/lr_generator.py b/model_zoo/official/cv/resnet152/src/lr_generator.py new file mode 100644 index 0000000000..92fa4adba8 --- /dev/null +++ b/model_zoo/official/cv/resnet152/src/lr_generator.py @@ -0,0 +1,199 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""learning rate generator""" +import math +import numpy as np + +def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps): + """ + Applies three steps decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + decay_epoch_index = [0.2 * total_steps, 0.5 * total_steps, 0.7 * total_steps, 0.9 * total_steps] + lr_each_step = [] + for i in range(total_steps): + if i < decay_epoch_index[0]: + lr = lr_max + elif i < decay_epoch_index[1]: + lr = lr_max * 0.1 + elif i < decay_epoch_index[2]: + lr = lr_max * 0.01 + elif i < decay_epoch_index[3]: + lr = lr_max * 0.001 + else: + lr = 0.00005 + lr_each_step.append(lr) + return lr_each_step + +def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies polynomial decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + lr_each_step = [] + if warmup_steps != 0: + inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps) + else: + inc_each_step = 0 + for i in range(total_steps): + if i < warmup_steps: + lr = float(lr_init) + inc_each_step * float(i) + else: + base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps))) + lr = float(lr_max) * base * base + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + return lr_each_step + +def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies cosine decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + decay_steps = total_steps - warmup_steps + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) + lr = float(lr_init) + lr_inc * (i + 1) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = lr_max * decayed + lr_each_step.append(lr) + return lr_each_step + +def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): + """ + Applies liner decay to generate learning rate array. + + Args: + lr_init(float): init learning rate. + lr_end(float): end learning rate + lr_max(float): max learning rate. + total_steps(int): all steps in training. + warmup_steps(int): all steps in warmup epochs. + + Returns: + np.array, learning rate array. + """ + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) + lr_each_step.append(lr) + return lr_each_step + +def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(int): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default) + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = int(steps_per_epoch * total_epochs) + # warmup_steps = steps_per_epoch * warmup_epochs + warmup_steps = warmup_epochs + + if lr_decay_mode == 'steps': + lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps) + elif lr_decay_mode == 'poly': + lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + elif lr_decay_mode == 'cosine': + lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + else: + lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps) + + lr_each_step = np.array(lr_each_step).astype(np.float32) + return lr_each_step + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0): + """ + generate learning rate array with cosine + + Args: + lr(float): base learning rate + steps_per_epoch(int): steps size of one epoch + warmup_epochs(int): number of warmup epochs + max_epoch(int): total epochs of training + global_step(int): the current start index of lr array + Returns: + np.array, learning rate array + """ + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + decay_steps = total_steps - warmup_steps + + lr_each_step = [] + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + linear_decay = (total_steps - i) / decay_steps + cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps)) + decayed = linear_decay * cosine_decay + 0.00001 + lr = base_lr * decayed + lr_each_step.append(lr) + + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[global_step:] + return learning_rate diff --git a/model_zoo/official/cv/resnet152/src/resnet.py b/model_zoo/official/cv/resnet152/src/resnet.py new file mode 100644 index 0000000000..237bc5b32e --- /dev/null +++ b/model_zoo/official/cv/resnet152/src/resnet.py @@ -0,0 +1,407 @@ +"""ResNet""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.tensor import Tensor +from scipy.stats import truncnorm + + +def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size): + fan_in = in_channel * kernel_size * kernel_size + scale = 1.0 + scale /= max(1., fan_in) + stddev = (scale ** 0.5) + mu, sigma = 0, stddev + weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size) + weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size)) + return Tensor(weight, dtype=mstype.float32) + + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3) + else: + weight_shape = (out_channel, in_channel, 3, 3) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv1x1(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1) + else: + weight_shape = (out_channel, in_channel, 1, 1) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv7x7(in_channel, out_channel, stride=1, use_se=False): + if use_se: + weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7) + else: + weight_shape = (out_channel, in_channel, 7, 7) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel, use_se=False): + if use_se: + weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel) + weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32) + else: + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + use_se (bool): enable SE-ResNet50 net. Default: False. + se_block(bool): use se block in SE-ResNet50 net. Default: False. + + Returns: + Tensor, output tensor. + + Examples: + # >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1, + use_se=False, + se_block=False): + super(ResidualBlock, self).__init__() + self.stride = stride + self.use_se = use_se + self.se_block = se_block + channel = out_channel // self.expansion + self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se) + self.bn1 = _bn(channel) + if self.use_se and self.stride != 1: + self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel), + nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')]) + else: + self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se) + self.bn2 = _bn(channel) + + self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se) + self.bn3 = _bn_last(out_channel) + if self.se_block: + self.se_global_pool = P.ReduceMean(keep_dims=False) + self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se) + self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se) + self.se_sigmoid = nn.Sigmoid() + self.se_mul = P.Mul() + self.relu = nn.ReLU() + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + + if self.down_sample: + if self.use_se: + if stride == 1: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, + stride, use_se=self.use_se), _bn(out_channel)]) + else: + self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'), + _conv1x1(in_channel, out_channel, 1, + use_se=self.use_se), _bn(out_channel)]) + else: + self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride, + use_se=self.use_se), _bn(out_channel)]) + self.add = P.TensorAdd() + + def construct(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + if self.use_se and self.stride != 1: + out = self.e2(out) + else: + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if self.se_block: + out_se = out + out = self.se_global_pool(out, (2, 3)) + out = self.se_dense_0(out) + out = self.relu(out) + out = self.se_dense_1(out) + out = self.se_sigmoid(out) + out = F.reshape(out, F.shape(out) + (1, 1)) + out = self.se_mul(out, out_se) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + use_se (bool): enable SE-ResNet50 net. Default: False. + # se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False. + Returns: + Tensor, output tensor. + + Examples: + # >>> ResNet(ResidualBlock, + # >>> [3, 4, 6, 3], + # >>> [64, 256, 512, 1024], + # >>> [256, 512, 1024, 2048], + # >>> [1, 2, 2, 2], + # >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes, + use_se=False): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + self.use_se = use_se + self.se_block = False + if self.use_se: + self.se_block = True + + if self.use_se: + self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se) + self.bn1_0 = _bn(32) + self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se) + self.bn1_1 = _bn(32) + self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se) + else: + self.conv1 = _conv7x7(3, 64, stride=2) # (224, 224, 3) --> (112, 112, 64) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0], + use_se=self.use_se) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1], + use_se=self.use_se) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2], + use_se=self.use_se, + se_block=self.se_block) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3], + use_se=self.use_se, + se_block=self.se_block) + + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + se_block(bool): use se block in SE-ResNet50 net. Default: False. + Returns: + SequentialCell, the output layer. + + Examples: + # >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se) + layers.append(resnet_block) + if se_block: + for _ in range(1, layer_num - 1): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block) + layers.append(resnet_block) + else: + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + def construct(self, x): + if self.use_se: + x = self.conv1_0(x) + x = self.bn1_0(x) + x = self.relu(x) + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu(x) + x = self.conv1_2(x) + else: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.mean(c5, (2, 3)) + out = self.flatten(out) + out = self.end_point(out) + + return out + + +def resnet50(class_num=10): + """ + Get ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + # >>> net = resnet50(10) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) + + +def se_resnet50(class_num=1001): + """ + Get SE-ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of SE-ResNet50 neural network. + + Examples: + # >>> net = se-resnet50(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num, + use_se=True) + + +def resnet101(class_num=1001): + """ + Get ResNet101 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet101 neural network. + + Examples: + # >>> net = resnet101(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 23, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) + + +def resnet152(class_num=1001): + """ + Get ResNet152 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet152 neural network. + + Examples: + # >>> net = resnet152(1001) + """ + return ResNet(ResidualBlock, + [3, 8, 36, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) diff --git a/model_zoo/official/cv/resnet152/train.py b/model_zoo/official/cv/resnet152/train.py new file mode 100644 index 0000000000..67df19e270 --- /dev/null +++ b/model_zoo/official/cv/resnet152/train.py @@ -0,0 +1,150 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train resnet.""" +import os +import argparse +import ast + +from mindspore import context +from mindspore import Tensor +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.model import Model +from mindspore.context import ParallelMode +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.loss_scale_manager import FixedLossScaleManager +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.communication.management import init, get_rank +from mindspore.common import set_seed +import mindspore.nn as nn +import mindspore.common.initializer as weight_init +from src.lr_generator import get_lr +from src.CrossEntropySmooth import CrossEntropySmooth +from src.resnet import resnet152 as resnet +from src.config import config5 as config +from src.dataset import create_dataset2 as create_dataset # imagenet2012 + +parser = argparse.ArgumentParser(description='Image classification--resnet152') +parser.add_argument('--data_url', type=str, default=None, help='Dataset path') +parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute') +parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') +parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') +parser.add_argument('--is_save_on_master', type=ast.literal_eval, default=True, help='save ckpt on master or all rank') +args_opt = parser.parse_args() + +set_seed(1) + +if __name__ == '__main__': + ckpt_save_dir = config.save_checkpoint_path + + # init context + print(args_opt.run_distribute) + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) + + if args_opt.run_distribute: + device_id = int(os.getenv('DEVICE_ID')) + rank_size = int(os.environ.get("RANK_SIZE", 1)) + print(rank_size) + device_num = rank_size + context.set_context(device_id=device_id, enable_auto_mixed_precision=True) + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True, all_reduce_fusion_config=[180, 313]) + init() + args_opt.rank = get_rank() + print(args_opt.rank) + + # select for master rank save ckpt or all rank save, compatible for model parallel + args_opt.rank_save_ckpt_flag = 0 + if args_opt.is_save_on_master: + if args_opt.rank == 0: + args_opt.rank_save_ckpt_flag = 1 + else: + args_opt.rank_save_ckpt_flag = 1 + local_data_path = args_opt.data_url + + local_data_path = args_opt.data_url + print('Download data:') + + # create dataset + dataset = create_dataset(dataset_path=local_data_path, do_train=True, repeat_num=1, + batch_size=config.batch_size, target="Ascend", distribute=args_opt.run_distribute) + + step_size = dataset.get_dataset_size() + print("step"+str(step_size)) + + # define net + net = resnet(class_num=config.class_num) + + # init weight + if args_opt.pre_trained: + param_dict = load_checkpoint(args_opt.pre_trained) + load_param_into_net(net, param_dict) + else: + for _, cell in net.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(), + cell.weight.shape, + cell.weight.dtype)) + if isinstance(cell, nn.Dense): + cell.weight.set_data(weight_init.initializer(weight_init.HeNormal(), + cell.weight.shape, + cell.weight.dtype)) + + # init lr + lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, + warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size, + lr_decay_mode=config.lr_decay_mode) + lr = Tensor(lr) + + # define opt + decayed_params = [] + no_decayed_params = [] + for param in net.trainable_params(): + if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: + decayed_params.append(param) + else: + no_decayed_params.append(param) + + group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) + + # define loss, model + if not config.use_label_smooth: + config.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + + loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) + model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, + metrics={'top_1_accuracy', 'top_5_accuracy'}, + amp_level="O3", keep_batchnorm_fp32=False) + + # define callbacks + time_cb = TimeMonitor(data_size=step_size) + loss_cb = LossMonitor() + cb = [time_cb, loss_cb] + if config.save_checkpoint: + if args_opt.rank_save_ckpt_flag: + config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="resnet152", directory=ckpt_save_dir, config=config_ck) + cb += [ckpt_cb] + + # train model + dataset_sink_mode = True + print(dataset.get_dataset_size()) + model.train(config.epoch_size, dataset, callbacks=cb, + sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)