!12699 ResNext101 64x4d mindspore ECNU liyiming

From: @neoming
Reviewed-by: @oacjiewen
Signed-off-by:
pull/12699/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 64a93c2089

@ -0,0 +1,224 @@
# ResNext101-64x4d for MindSpore
本仓库提供了ResNeXt101-64x4d模型的训练脚本和超参配置以达到论文中的准确性。
## 模型概述
模型名称ResNeXt101
论文:`"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`
这里提供的版本是ResNeXt101-64x4d
### 模型架构
ResNeXt是ResNet网络的改进版本比ResNet的网络多了块多了cardinality设置。ResNeXt101-64x4d的网络结构如下
| 网络层 | 输出 | 参数 |
| ---------- | ------- | ------------------------------------------- |
| conv1 | 112x112 | 7x7,64,stride 2 |
| maxpooline | 56x56 | 3x3,stride 2 |
| conv2 | 56x56 | [(1x1,64)->(3x3,64)->(1x1,256) C=64]x3 |
| conv3 | 28x28 | [(1x1,256)->(3x3,256)->(1x1,512) C=64]x4 |
| conv4 | 14x14 | [(1x1,512)->(3x3,512)->(1x1,1024) C=64]x23 |
| conv5 | 7x7 | [(1x1,1024)->(3x3,1024)->(1x1,2048) C=64]x3 |
| | 1x1 | average pool1000-d fcsoftmax |
### 默认设置
以下各节介绍ResNext50模型的默认配置和超参数。
#### 优化器
本模型使用Mindspore框架提供的Momentum优化器超参设置如下
- Momentum : 0.9
- Learning rate (LR) : 0.05
- LR schedule: cosine_annealing
- LR epochs: [30, 60, 90, 120]
- LR gamma: 0.1
- Batch size : 64
- Weight decay : 0.0001.
- Label smoothing = 0.1
- Eta_min: 0
- Warmup_epochs: 1
- Loss_scale: 1024
- 训练轮次:
- 150 epochs
#### 数据增强
本模型使用了以下数据增强:
- 对于训练脚本:
- RandomResizeCrop, scale=(0.08, 1.0), ratio=(0.75, 1.333)
- RandomHorizontalFlip, prob=0.5
- Normalize, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
- 对于验证(前向推理):
- Resize to (256, 256)
- CenterCrop to (224, 224)
- Normalize, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
## 设定
以下各节列出了开始训练ResNext101-64x4d模型的要求。
## 快速入门指南
目录说明代码参考了Modelzoo上的[ResNext50_for_MindSpore](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/resnext50)
```path
.
└─resnext101-64x4d-mindspore
├─README.md
├─scripts
├─run_standalone_train.sh # 启动Ascend单机训练单卡
├─run_distribute_train.sh # 启动Ascend分布式训练8卡
├─run_standalone_train_for_gpu.sh # 启动GPU单机训练单卡
├─run_distribute_train_for_gpu.sh # 启动GPU分布式训练8卡
└─run_eval.sh # 启动评估
├─src
├─backbone
├─_init_.py # 初始化
├─resnext.py # ResNeXt101骨干
├─utils
├─_init_.py # 初始化
├─cunstom_op.py # 网络操作
├─logging.py # 打印日志
├─optimizers_init_.py # 获取参数
├─sampler.py # 分布式采样器
├─var_init_.py # 计算增益值
├─_init_.py # 初始化
├─config.py # 参数配置
├─crossentropy.py # 交叉熵损失函数
├─dataset.py # 数据预处理
├─head.py # 常见头
├─image_classification.py # 获取ResNet
├─linear_warmup.py # 线性热身学习率
├─warmup_cosine_annealing.py # 每次迭代的学习率
├─warmup_step_lr.py # 热身迭代学习率
├─eval.py # 评估网络
├──train.py # 训练网络
├──mindspore_hub_conf.py # MindSpore Hub接口
```
### 1. 仓库克隆
```shell
git clone https://gitee.com/neoming/resnext101-64x4d-mindspore.git
cd resnext101-64x4d-mindspore/
```
### 2. 数据下载和预处理
1. 下载ImageNet数据集
2. 解压训练数据集和验证数据
3. 训练和验证图像分别位于train /和val /目录下。 一个文件夹中的所有图像都具有相同的标签。
### 3. 训练(单卡)
可以通过python脚本开始训练
```shell
python train.py --data_dir ~/imagenet/train/ --platform Ascend --is_distributed
```
或通过shell脚本开始训练
```shell
Ascend:
# 分布式训练示例8卡
bash scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# 单机训练
bash scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
GPU:
# 分布式训练示例8卡
bash scripts/run_distribute_train_for_gpu.sh DATA_PATH
# 单机训练
bash scripts/run_standalone_train_for_gpu.sh DEVICE_ID DATA_PATH
```
### 4. 测试
您可以通过python脚本开始验证
```shell
python eval.py --data_dir ~/imagenet/val/ --platform Ascend --pretrained resnext.ckpt
```
或通过shell脚本开始训练
```shell
# 评估
bash scripts/run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
```
## 模型导出
```shell
python export.py --device_target [PLATFORM] --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT]
```
`EXPORT_FORMAT` 可选 ["AIR", "ONNX", "MINDIR"].
## 高级设置
### 超参设置
通过`src/config.py`文件进行设置下面是ImageNet单卡实验的设置
```python
"image_size": '224,224',
"num_classes": 1000,
"lr": 0.05,
"lr_scheduler": 'cosine_annealing',
"lr_epochs": '30,60,90,120',
"lr_gamma": 0.1,
"eta_min": 0,
"T_max": 150,
"max_epoch": 150,
"backbone": 'resnext101',
"warmup_epochs": 1,
"weight_decay": 0.0001,
"momentum": 0.9,
"is_dynamic_loss_scale": 0,
"loss_scale": 1024,
"label_smooth": 1,
"label_smooth_factor": 0.1,
"ckpt_interval": 1250,
"ckpt_path": 'outputs/',
"is_save_on_master": 1,
"rank": 0,
"group_size": 1
```
### 训练过程
训练的所有结果将存储在用--ckpt_path参数指定的目录中。
训练脚本将会存储:
- checkpoints.
- log.
## 性能
### 结果
通过运行训练脚本获得了以下结果。 要获得相同的结果,请遵循快速入门指南中的步骤。
#### 准确度
| **epochs** | Top1/Top5 |
| :--------: | :-----------: |
| 150 | 79.56%(TOP1)/94.68%(TOP5) |
#### 训练性能结果
| **NPUs** | train performance |
| :------: | :---------------: |
| 1 | 196.33image/sec |

File diff suppressed because it is too large Load Diff

@ -0,0 +1,47 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
resnext export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.image_classification import get_network
parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument('--width', type=int, default=224, help='input width')
parser.add_argument('--height', type=int, default=224, help='input height')
parser.add_argument("--file_name", type=str, default="resnext50", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = get_network(num_classes=config.num_classes, platform=args.device_target)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_shp = [args.batch_size, 3, args.height, args.width]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(net, input_array, file_name=args.file_name, file_format=args.file_format)

@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
echo "hccl connect time out has changed to 600 second"
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG$i
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../train.py \
--is_distribute \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &
cd ../
done

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
export RANK_SIZE=8
PATH_CHECKPOINT=""
if [ $# == 2 ]
then
PATH_CHECKPOINT=$2
fi
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--is_distribute \
--platform="GPU" \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
PLATFORM=Ascend
if [ $# == 4 ]
then
PLATFORM=$4
fi
python eval.py \
--pretrained=$PATH_CHECKPOINT \
--platform=$PLATFORM \
--data_dir=$DATA_DIR > log.txt 2>&1 &

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
python train.py \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
python train.py \
--pretrained=$PATH_CHECKPOINT \
--platform="GPU" \
--data_dir=$DATA_DIR > log.txt 2>&1 &

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""resnext"""
from .resnext import *

File diff suppressed because it is too large Load Diff

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
from easydict import EasyDict as ed
config = ed({
"image_size": '224,224',
"num_classes": 1000,
"lr": 0.4,
"lr_scheduler": 'cosine_annealing',
"lr_epochs": '30,60,90,120',
"lr_gamma": 0.1,
"eta_min": 0,
"T_max": 150,
"max_epoch": 150,
"warmup_epochs": 1,
"weight_decay": 0.0001,
"momentum": 0.9,
"is_dynamic_loss_scale": 0,
"loss_scale": 1024,
"label_smooth": 1,
"label_smooth_factor": 0.1,
"ckpt_interval": 5,
"ckpt_save_max": 5,
"ckpt_path": 'outputs/',
"is_save_on_master": 1,
# this two parameter is used for mindspore distributed configuration
"rank": 0,
"group_size": 1
})

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss function for network.
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""
the redefined loss function with SoftmaxCrossEntropyWithLogits.
"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

@ -0,0 +1,158 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as V_C
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
class TxtDataset():
"""
create txt dataset.
Args:
Returns:
de_dataset.
"""
def __init__(self, root, txt_name):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
fin = open(txt_name, "r")
for line in fin:
img_name, label = line.strip().split(' ')
self.imgs.append(os.path.join(root, img_name))
self.labels.append(int(label))
fin.close()
def __getitem__(self, index):
img = Image.open(self.imgs[index]).convert('RGB')
return img, self.labels[index]
def __len__(self):
return len(self.imgs)
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
mode='train',
input_mode='folder',
root='',
num_parallel_workers=None,
shuffle=None,
sampler=None,
class_indexing=None,
drop_remainder=True,
transform=None,
target_transform=None):
"""
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
are written into a textfile.
Args:
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
Or path of the textfile that contains every image's path of the dataset.
image_size (Union(int, sequence)): Size of the input images.
per_batch_size (int): the batch size of evey step during training.
max_epoch (int): the number of epochs.
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided
into (default=None).
mode (str): "train" or others. Default: " train".
input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
root (str): the images path for "input_mode="txt"". Default: " ".
num_parallel_workers (int): Number of workers to read the data. Default: None.
shuffle (bool): Whether or not to perform shuffle on the dataset
(default=None, performs shuffle).
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
class_indexing (dict): A str-to-int mapping from folder name to index
(default=None, the folder names will be sorted
alphabetically and each class will be given a
unique index starting from 0).
Examples:
>>> from src.dataset import classification_dataset
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
>>> data_dir = "/path/to/imagefolder_directory"
>>> de_dataset = classification_dataset(data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4)
>>> # Path of the textfile that contains every image's path of the dataset.
>>> data_dir = "/path/to/dataset/images/train.txt"
>>> images_dir = "/path/to/dataset/images"
>>> de_dataset = classification_dataset(data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4,
>>> input_mode="txt", root=images_dir)
"""
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if transform is None:
if mode == 'train':
transform_img = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5),
V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = [
V_C.Decode(),
V_C.Resize((256, 256)),
V_C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = transform
if target_transform is None:
transform_label = [C.TypeCast(mstype.int32)]
else:
transform_label = target_transform
if input_mode == 'folder':
de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=group_size, shard_id=rank)
else:
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset = de_dataset.map(operations=transform_img, input_columns="image",
num_parallel_workers=num_parallel_workers)
de_dataset = de_dataset.map(operations=transform_label, input_columns="label",
num_parallel_workers=num_parallel_workers)
columns_to_project = ["image", "label"]
de_dataset = de_dataset.project(columns=columns_to_project)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
de_dataset = de_dataset.repeat(max_epoch)
return de_dataset

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
common architecture.
"""
import mindspore.nn as nn
from src.utils.cunstom_op import GlobalAvgPooling
__all__ = ['CommonHead']
class CommonHead(nn.Cell):
"""
common architecture definition.
Args:
num_classes (int): Number of classes.
out_channels (int): Output channels.
Returns:
Tensor, output tensor.
"""
def __init__(self, num_classes, out_channels):
super(CommonHead, self).__init__()
self.avgpool = GlobalAvgPooling()
self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True)
def construct(self, x):
x = self.avgpool(x)
x = self.fc(x)
return x

@ -0,0 +1,98 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classifiation.
"""
import math
import mindspore.nn as nn
from mindspore.common import initializer as init
import src.backbone as backbones
import src.head as heads
from src.utils.var_init import default_recurisive_init, KaimingNormal
class ImageClassificationNetwork(nn.Cell):
"""
architecture of image classification network.
Args:
Returns:
Tensor, output tensor.
"""
def __init__(self, backbone, head, include_top=True, activation="None"):
super(ImageClassificationNetwork, self).__init__()
self.backbone = backbone
self.include_top = include_top
self.need_activation = False
if self.include_top:
self.head = head
if activation != "None":
self.need_activation = True
if activation == "Sigmoid":
self.activation = P.Sigmoid()
elif activation == "Softmax":
self.activation = P.Softmax()
else:
raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].")
def construct(self, x):
x = self.backbone(x)
if self.include_top:
x = self.head(x)
if self.need_activation:
x = self.activation(x)
return x
class ResNeXt(ImageClassificationNetwork):
"""
ResNeXt architecture.
Args:
backbone_name (string): backbone.
num_classes (int): number of classes, Default is 1000.
Returns:
Resnet.
"""
def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"):
self.backbone_name = backbone_name
backbone = backbones.__dict__[self.backbone_name](platform=platform)
out_channels = backbone.get_out_channels()
head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
super(ResNeXt, self).__init__(backbone, head, include_top, activation)
default_recurisive_init(self)
for cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(
KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype))
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
for cell in self.cells_and_names():
if isinstance(cell, backbones.resnet.Bottleneck):
cell.bn3.gamma.set_data(init.initializer('zeros', cell.bn3.gamma.shape))
elif isinstance(cell, backbones.resnet.BasicBlock):
cell.bn2.gamma.set_data(init.initializer('zeros', cell.bn2.gamma.shape))
def get_network(**kwargs):
return ResNeXt('resnext101', **kwargs)

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
learning rate generator.
"""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""
Applies liner Increasing to generate learning rate array in warmup stage.
Args:
current_step(int): current step in warmup stage.
warmup_steps(int): all steps in warmup stage.
base_lr(float): init learning rate.
init_lr(float): end learning rate
Returns:
float, learning rate.
"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""
Applies cosine decay to generate learning rate array with warmup.
Args:
lr(float): init learning rate
steps_per_epoch(int): steps of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epoch of training
T_max(int): max epoch in decay.
eta_min(float): end learning rate
Returns:
np.array, learning rate array.
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""
Applies step decay to generate learning rate array with warmup.
Args:
lr(float): init learning rate
lr_epochs(list): learning rate decay epoches list
steps_per_epoch(int): steps of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epoch of training
gamma(float): attenuation constants.
Returns:
np.array, learning rate array.
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def get_lr(args):
"""generate learning rate array."""
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
return lr

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
class OutputTo(nn.Cell):
"Cast cell output back to float16 or float32"
def __init__(self, op, to_type=mstype.float16):
super(OutputTo, self).__init__(auto_prefix=False)
self._op = op
validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
self.to_type = to_type
def construct(self, x):
return F.cast(self._op(x), self.to_type)
def auto_mixed_precision(network):
"""Do keep batchnorm fp32."""
cells = network.name_cells()
change = False
network.to_float(mstype.float16)
for name in cells:
subcell = cells[name]
if subcell == network:
continue
elif name == 'fc':
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
change = True
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
change = True
else:
auto_mixed_precision(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())

@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network operations
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class GlobalAvgPooling(nn.Cell):
"""
global average pooling feature map.
Args:
mean (tuple): means for each channel.
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class SEBlock(nn.Cell):
"""
squeeze and excitation block.
Args:
channel (int): number of feature maps.
reduction (int): weight.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = GlobalAvgPooling()
self.fc1 = nn.Dense(channel, channel // reduction)
self.relu = P.ReLU()
self.fc2 = nn.Dense(channel // reduction, channel)
self.sigmoid = P.Sigmoid()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.sum = P.Sum()
self.cast = P.Cast()
def construct(self, x):
b, c = self.shape(x)
y = self.avg_pool(x)
y = self.reshape(y, (b, c))
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y)
y = self.reshape(y, (b, c, 1, 1))
return x * y
class GroupConv(nn.Cell):
"""
group convolution operation.
Args:
in_channels (int): Input channels of feature map.
out_channels (int): Output channels of feature map.
kernel_size (int): Size of convolution kernel.
stride (int): Stride size for the group convolution layer.
Returns:
tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
super(GroupConv, self).__init__()
assert in_channels % groups == 0 and out_channels % groups == 0
self.groups = groups
self.convs = nn.CellList()
self.op_split = P.Split(axis=1, output_num=self.groups)
self.op_concat = P.Concat(axis=1)
self.cast = P.Cast()
for _ in range(groups):
self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups,
kernel_size=kernel_size, stride=stride, has_bias=has_bias,
padding=pad, pad_mode=pad_mode, group=1))
def construct(self, x):
features = self.op_split(x)
outputs = ()
for i in range(self.groups):
outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
out = self.op_concat(outputs)
return out

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
optimizer parameters.
"""
def get_param_groups(network):
"""get param groups"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np
class DistributedSampler():
"""
sampling the dataset.
Args:
Returns:
num_samples, number of samples.
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_length = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_length).tolist()
else:
indices = list(range(len(self.dataset_length)))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples

@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import os
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
def load_pretrain_model(ckpt_file, network, args):
"""load pretrain model."""
if os.path.isfile(ckpt_file):
param_dict = load_checkpoint(ckpt_file)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(ckpt_file))

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save