!8992 Add mobilenetv1

From: @c_34
Reviewed-by: 
Signed-off-by:
pull/8992/MERGE
Committed by mindspore-ci-bot via Gitee
commit 893c2cd772

@ -0,0 +1,187 @@
# Mobilenet_V1
- [Mobilenet_V1](#mobilenet_v1)
    - [MobileNetV1 Description](#mobilenetv1-description)
    - [Model architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Features](#features)
        - [Mixed Precision(Ascend)](#mixed-precisionascend)
    - [Environment Requirements](#environment-requirements)
    - [Script description](#script-description)
        - [Script and sample code](#script-and-sample-code)
    - [Training process](#training-process)
        - [Usage](#usage)
        - [Launch](#launch)
        - [Result](#result)
    - [Evaluation process](#evaluation-process)
        - [Usage](#usage-1)
        - [Launch](#launch-1)
        - [Result](#result-1)
    - [Model description](#model-description)
        - [Performance](#performance)
            - [Training Performance](#training-performance)
    - [Description of Random Situation](#description-of-random-situation)
    - [ModelZoo Homepage](#modelzoo-homepage)
## [MobileNetV1 Description](#contents)
MobileNetV1 is an efficient network for mobile and embedded vision applications. It is based on a streamlined architecture that uses depth-wise separable convolutions to build lightweight deep neural networks.
[Paper](https://arxiv.org/abs/1704.04861) Howard A G , Zhu M , Chen B , et al. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications[J]. 2017.
## [Model architecture](#contents)
The overall network architecture of MobileNetV1 is shown below:
[Link](https://arxiv.org/abs/1704.04861)
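Each MobileNetV1 block factorizes a standard convolution into a 3x3 depthwise convolution (one filter per input channel) followed by a 1x1 pointwise convolution that mixes channels. Below is a minimal MindSpore sketch of one such block, mirroring `conv_bn_relu` in `src/mobilenet_v1.py`; the helper name is illustrative, not part of this PR:

```python
import mindspore.nn as nn

def depthwise_separable(in_ch, out_ch, stride=1):
    """3x3 depthwise conv + 1x1 pointwise conv, each with BatchNorm and ReLU6."""
    return nn.SequentialCell([
        # depthwise: group == in_ch gives one filter per input channel
        nn.Conv2d(in_ch, in_ch, 3, stride, pad_mode="same", group=in_ch),
        nn.BatchNorm2d(in_ch),
        nn.ReLU6(),
        # pointwise: 1x1 conv mixes channels and sets the output width
        nn.Conv2d(in_ch, out_ch, 1, 1, pad_mode="same"),
        nn.BatchNorm2d(out_ch),
        nn.ReLU6(),
    ])
```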
## [Dataset](#contents)
Dataset used: [ImageNet2012](http://www.image-net.org/)
- Dataset size: 224*224 colorful images in 1000 classes
    - Train: 1,281,167 images
    - Test: 50,000 images
- Data format: JPEG
    - Note: Data will be processed in dataset.py (a usage sketch follows the directory layout below)
- Download the dataset; the directory structure is as follows:
```bash
└─dataset
    ├─ilsvrc                 # train dataset
    └─validation_preprocess  # evaluate dataset
```
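The processing mentioned above lives in `src/dataset.py`; the sketch below shows how the evaluation pipeline is built, mirroring the call in `eval.py` (the dataset path is a placeholder):

```python
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset

# ImageNet2012 eval pipeline: decode, resize to 256, center-crop to 224,
# normalize, HWC->CHW, then batch (see create_dataset2 in src/dataset.py)
dataset = create_dataset(dataset_path="/path/to/validation_preprocess",
                         do_train=False, batch_size=config.batch_size,
                         target="Ascend")
print(dataset.get_dataset_size())  # number of batches per epoch
```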
## [Features](#contents)
### [Mixed Precision(Ascend)](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for "reduce precision".
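In this model, mixed precision is enabled through the `Model` wrapper rather than by casting operators manually; this is the call used in `train.py` (with `net`, `loss`, `opt` and `loss_scale` defined as in that script):

```python
# amp_level="O2" runs the network in float16 while keeping the loss in float32;
# keep_batchnorm_fp32=False lets BatchNorm run in float16 as well
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
              metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
```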
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Script description](#contents)
### [Script and sample code](#contents)
```text
├── MobileNetV1
    ├── README.md                    # descriptions about MobileNetV1
    ├── scripts
    │   ├── run_distribute_train.sh  # shell script for distributed training
    │   ├── run_standalone_train.sh  # shell script for standalone training
    │   └── run_eval.sh              # shell script for evaluation
    ├── src
    │   ├── config.py                # parameter configuration
    │   ├── dataset.py               # dataset creation
    │   ├── lr_generator.py          # learning rate schedule
    │   ├── mobilenet_v1.py          # MobileNetV1 architecture
    │   └── CrossEntropySmooth.py    # loss function
    ├── train.py                     # training script
    └── eval.py                      # evaluation script
```
## [Training process](#contents)
### Usage
You can start training using Python or shell scripts. The usage of shell scripts is as follows:
- Ascend: sh run_distribute_train.sh [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH] (optional)
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link [hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools); for example, `python hccl_tools.py --device_num "[0,8)"` generates a rank table for devices 0-7.
### Launch
```shell
# training example
python:
Ascend: python train.py --device_target Ascend --dataset [cifar10|imagenet2012] --dataset_path [TRAIN_DATASET_PATH]
shell:
Ascend: sh run_distribute_train.sh [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
```
### Result
Training results will be stored in the example path. Checkpoints are stored in `ckpt_*` by default, and the training log is written to `./train_parallel*/log` on the Ascend platform.
```shell
epoch: 89 step: 1251, loss is 2.1829057
Epoch time: 146826.802, per step time: 117.368
epoch: 90 step: 1251, loss is 2.3499017
Epoch time: 150950.623, per step time: 120.664
```
## [Evaluation process](#contents)
### Usage
You can start evaluation using Python or shell scripts. The usage of shell scripts is as follows:
- Ascend: sh run_eval.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
### Launch
```shell
# eval example
python:
Ascend: python eval.py --dataset [cifar10|imagenet2012] --dataset_path [VAL_DATASET_PATH] --checkpoint_path [CHECKPOINT_PATH]
shell:
Ascend: sh run_eval.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
> The checkpoint can be produced during the training process.
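To reuse a trained checkpoint outside `eval.py`, the loading pattern is the same one `eval.py` uses (the checkpoint path below is the one from the sample result):

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.mobilenet_v1 import mobilenet_v1

net = mobilenet_v1(class_num=1001)
param_dict = load_checkpoint("./train_parallel0/ckpt_0/mobilenetv1-90_1251.ckpt")
load_param_into_net(net, param_dict)
net.set_train(False)  # switch BatchNorm and friends to inference behavior
```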
### Result
Inference results will be stored in the example path; you can find results like the following in `eval/log`.
```shell
result: {'top_5_accuracy': 0.9010016025641026, 'top_1_accuracy': 0.7128004807692307} ckpt=./train_parallel0/ckpt_0/mobilenetv1-90_1251.ckpt
```
## [Model description](#contents)
### [Performance](#contents)
#### Training Performance
| Parameters | MobilenetV1 |
| -------------------------- | ------------------------------------------------------------------------------------------- |
| Model Version | V1 |
| Resource                   | Ascend 910 x 4; CPU 2.60 GHz, 192 cores; memory 755 GB                                       |
| Uploaded Date              | 11/28/2020                                                                                    |
| MindSpore Version | 1.0.0 |
| Dataset | ImageNet2012 |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropy |
| outputs | probability |
| Loss | 2.3499017 |
| Accuracy                   | Top-1: 71.28%                                                                                 |
| Total time | 225 min |
| Params (M) | 3.3 M |
| Checkpoint for Fine tuning | 27.3 M |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/mobilenetv1) |
## [Description of Random Situation](#contents)
In train.py, we set the seed which is used by numpy.random, mindspore.common.Initializer, mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,79 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval mobilenet_v1."""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
from src.mobilenet_v1 import mobilenet_v1 as mobilenet
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
set_seed(1)
if args_opt.dataset == 'cifar10':
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
if __name__ == '__main__':
target = args_opt.device_target
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if target != "GPU":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size,
target=target)
step_size = dataset.get_dataset_size()
# define net
net = mobilenet(class_num=config.class_num)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction='mean',
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

@ -0,0 +1,94 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_distribute_train.sh [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "cifar10" ] && [ $1 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
if [ $# == 4 ]
then
PATH3=$(get_real_path $4)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 4 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
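# launch one training process per device; each gets its own working directory and rank id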
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=${i}
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 3 ]
then
python train.py --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log &
fi
if [ $# == 4 ]
then
python train.py --dataset=$1 --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log &
fi
cd ..
done

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "Usage: sh run_eval.sh [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "cifar10" ] && [ $1 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset=$1 --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ] && [ $# != 3 ]
then
echo "Usage: sh run_distribute_train.sh [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "cifar10" ] && [ $1 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $2)
if [ $# == 3 ]
then
PATH2=$(get_real_path $3)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 3 ] && [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ $# == 2 ]
then
python train.py --dataset=$1 --dataset_path=$PATH1 &> log &
fi
if [ $# == 3 ]
then
python train.py --dataset=$1 --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
fi
cd ..

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
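        # label smoothing: the true class gets 1 - smooth_factor, every other class gets smooth_factor / (num_classes - 1)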
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config for mobilenet, cifar10
config1 = ed({
"class_num": 10,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 90,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr_decay_mode": "poly",
"lr_init": 0.01,
"lr_end": 0.00001,
"lr_max": 0.1
})
# config for mobilenet, imagenet2012
config2 = ed({
"class_num": 1001,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 90,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "linear",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.8,
"lr_end": 0.0
})

@ -0,0 +1,155 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or evaluate cifar10 dataset for mobilenet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init()
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
# define map operations
trans = []
if do_train:
trans += [
C.RandomCrop((32, 32), (4, 4, 4, 4)),
C.RandomHorizontalFlip(prob=0.5)
]
trans += [
C.Resize((224, 224)),
C.Rescale(1.0 / 255.0, 0.0),
C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for mobilenet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init()
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
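    # mean/std are scaled to [0, 255] because Normalize here runs on raw decoded pixels (no Rescale step)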
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
ds = ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id

@ -0,0 +1,207 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
"""
Applies three steps decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
return lr_each_step
def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
"""
Applies polynomial decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
lr_each_step = []
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
return lr_each_step
def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
"""
Applies cosine decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
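            # past warmup: a linear ramp-down times a shifted cosine, floored at +1e-5 so the lr never hits zero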
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
return lr_each_step
def _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
    """
    Applies linear decay to generate learning rate array.
Args:
lr_init(float): init learning rate.
lr_end(float): end learning rate
lr_max(float): max learning rate.
total_steps(int): all steps in training.
warmup_steps(int): all steps in warmup epochs.
Returns:
np.array, learning rate array.
"""
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
return lr_each_step
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
        lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or linear (default)
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
elif lr_decay_mode == 'poly':
lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
elif lr_decay_mode == 'cosine':
lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
else:
        lr_each_step = _generate_linear_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
"""
generate learning rate array with cosine
Args:
lr(float): base learning rate
steps_per_epoch(int): steps size of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epochs of training
global_step(int): the current start index of lr array
Returns:
np.array, learning rate array
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = base_lr * decayed
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[global_step:]
return learning_rate

@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore.ops import operations as P
def conv_bn_relu(in_channel, out_channel, kernel_size, stride, depthwise, activation='relu6'):
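    """Conv2d + BatchNorm + optional activation; depthwise=True makes the conv per-channel (group=in_channel)."""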
output = []
output.append(nn.Conv2d(in_channel, out_channel, kernel_size, stride, pad_mode="same",
group=1 if not depthwise else in_channel))
output.append(nn.BatchNorm2d(out_channel))
if activation:
output.append(nn.get_activation(activation))
return nn.SequentialCell(output)
class MobileNetV1(nn.Cell):
"""
MobileNet V1 backbone
"""
def __init__(self, class_num=1001, features_only=False):
super(MobileNetV1, self).__init__()
self.features_only = features_only
cnn = [
conv_bn_relu(3, 32, 3, 2, False), # Conv0
conv_bn_relu(32, 32, 3, 1, True), # Conv1_depthwise
conv_bn_relu(32, 64, 1, 1, False), # Conv1_pointwise
conv_bn_relu(64, 64, 3, 2, True), # Conv2_depthwise
conv_bn_relu(64, 128, 1, 1, False), # Conv2_pointwise
conv_bn_relu(128, 128, 3, 1, True), # Conv3_depthwise
conv_bn_relu(128, 128, 1, 1, False), # Conv3_pointwise
conv_bn_relu(128, 128, 3, 2, True), # Conv4_depthwise
conv_bn_relu(128, 256, 1, 1, False), # Conv4_pointwise
conv_bn_relu(256, 256, 3, 1, True), # Conv5_depthwise
conv_bn_relu(256, 256, 1, 1, False), # Conv5_pointwise
conv_bn_relu(256, 256, 3, 2, True), # Conv6_depthwise
conv_bn_relu(256, 512, 1, 1, False), # Conv6_pointwise
conv_bn_relu(512, 512, 3, 1, True), # Conv7_depthwise
conv_bn_relu(512, 512, 1, 1, False), # Conv7_pointwise
conv_bn_relu(512, 512, 3, 1, True), # Conv8_depthwise
conv_bn_relu(512, 512, 1, 1, False), # Conv8_pointwise
conv_bn_relu(512, 512, 3, 1, True), # Conv9_depthwise
conv_bn_relu(512, 512, 1, 1, False), # Conv9_pointwise
conv_bn_relu(512, 512, 3, 1, True), # Conv10_depthwise
conv_bn_relu(512, 512, 1, 1, False), # Conv10_pointwise
conv_bn_relu(512, 512, 3, 1, True), # Conv11_depthwise
conv_bn_relu(512, 512, 1, 1, False), # Conv11_pointwise
conv_bn_relu(512, 512, 3, 2, True), # Conv12_depthwise
conv_bn_relu(512, 1024, 1, 1, False), # Conv12_pointwise
conv_bn_relu(1024, 1024, 3, 1, True), # Conv13_depthwise
conv_bn_relu(1024, 1024, 1, 1, False), # Conv13_pointwise
]
if self.features_only:
self.network = nn.CellList(cnn)
else:
self.network = nn.SequentialCell(cnn)
self.fc = nn.Dense(1024, class_num)
def construct(self, x):
output = x
if self.features_only:
features = ()
for block in self.network:
output = block(output)
features = features + (output,)
return features
output = self.network(x)
output = P.ReduceMean()(output, (2, 3))
output = self.fc(output)
return output
def mobilenet_v1(class_num=1001):
return MobileNetV1(class_num)

@ -0,0 +1,163 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train mobilenet_v1."""
import os
import argparse
import ast
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.common import set_seed
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.lr_generator import get_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.mobilenet_v1 import mobilenet_v1 as mobilenet
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset', type=str, default=None, help='Dataset, either cifar10 or imagenet2012')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
args_opt = parser.parse_args()
set_seed(1)
if args_opt.dataset == 'cifar10':
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
if __name__ == '__main__':
target = args_opt.device_target
ckpt_save_dir = config.save_checkpoint_path
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if args_opt.parameter_server:
context.set_ps_context(enable_ps=True)
if args_opt.run_distribute:
if target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
# GPU target
else:
init()
context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target)
step_size = dataset.get_dataset_size()
# define net
net = mobilenet(class_num=config.class_num)
if args_opt.parameter_server:
net.set_param_ps()
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
# init lr
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
# define opt
decayed_params = []
no_decayed_params = []
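    # apply weight decay only to conv/dense weights; BatchNorm gamma/beta and bias terms are excluded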
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model
if target == "Ascend":
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
else:
# GPU target
if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
config.loss_scale)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# Mixed precision
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False)
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="mobilenetv1", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
# train model
model.train(config.epoch_size - config.pretrain_epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size(), dataset_sink_mode=(not args_opt.parameter_server))

File diff suppressed because it is too large.