Adding TinyNet to Model Zoo

Adding TinyNet (https://arxiv.org/abs/2010.14819) MindSpore implementation to model Zoo
pull/8103/head
yanglf1121 4 years ago
parent fc5a3b7d97
commit 882301f4b5

@ -60,6 +60,7 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [GhostNet_Quant](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ghostnet_quant/README.md)
- [ResNet50-0.65x](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/resnet50_adv_pruning/README.md)
- [SSD_GhostNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ssd_ghostnet/README.md)
- [TinyNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/tinynet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)

@ -0,0 +1,154 @@
# Contents
- [TinyNet Description](#tinynet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [TinyNet Description](#contents)
TinyNets are a series of lightweight models obtained by twisting resolution, depth and width with a data-driven tiny formula. TinyNet outperforms EfficientNet and MobileNetV3.
[Paper](https://arxiv.org/abs/2010.14819): Kai Han, Yunhe Wang, Qiulin Zhang, Wei Zhang, Chunjing Xu, Tong Zhang. Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets. In NeurIPS 2020.
Note: We have only released TinyNet-C for now, and will release other TinyNets soon.
# [Model architecture](#contents)
The overall network architecture of TinyNet is show below:
[Link](https://arxiv.org/abs/2010.14819)
# [Dataset](#contents)
Dataset used: [ImageNet 2012](http://image-net.org/challenges/LSVRC/2012/)
- Dataset size:
- Train: 1.2 million images in 1,000 classes
- Test: 50,000 validation images in 1,000 classes
- Data format: RGB images.
- Note: Data will be processed in src/dataset/dataset.py
# [Environment Requirements](#contents)
- Hardware (GPU)
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```
.tinynet
├── Readme.md # descriptions about tinynet
├── script
│ ├── eval.sh # evaluation script
│ ├── train_1p_gpu.sh # training script on single GPU
│ └── train_distributed_gpu.sh # distributed training script on multiple GPUs
├── src
│ ├── callback.py # loss and checkpoint callbacks
│ ├── dataset.py # data processing
│ ├── loss.py # label-smoothing cross-entropy loss function
│ ├── tinynet.py # tinynet architecture
│ └── utils.py # utility functions
├── eval.py # evaluation interface
└── train.py # training interface
```
## [Training process](#contents)
### Launch
```
# training on single GPU
sh train_1p_gpu.sh
# training on multiple GPUs, the number after -n indicates how many GPUs will be used for training
sh train_distributed_gpu.sh -n 8
```
Inside train.sh, there are hyperparameters that can be adjusted during training, for example:
```
--model tinynet_c model to be used for training
--drop 0.2 dropout rate
--drop-connect 0 drop connect rate
--num-classes 1000 number of classes for training
--opt-eps 0.001 optimizer's epsilon
--lr 0.048 learning rate
--batch-size 128 batch size
--decay-epochs 2.4 learning rate decays every 2.4 epoch
--warmup-lr 1e-6 warm up learning rate
--warmup-epochs 3 learning rate warm up epoch
--decay-rate 0.97 learning rate decay rate
--ema-decay 0.9999 decay factor for model weights moving average
--weight-decay 1e-5 optimizer's weight decay
--epochs 450 number of epochs to be trained
--ckpt_save_epoch 1 checkpoint saving interval
--workers 8 number of processes for loading data
--amp_level O0 training auto-mixed precision
--opt rmsprop optimizers, currently we support SGD and RMSProp
--data_path /path_to_ImageNet/
--GPU using GPU for training
--dataset_sink using sink mode
```
The config above was used to train tinynets on ImageNet (change drop-connect to 0.2 for training tinynet-b)
> checkpoints will be saved in the ./device_{rank_id} folder (single GPU)
or ./device_parallel folder (multiple GPUs)
## [Eval process](#contents)
### Launch
```
# infer example
sh eval.sh
```
Inside the eval.sh, there are configs that can be adjusted during inference, for example:
```
--num-classes 1000
--batch-size 128
--workers 8
--data_path /path_to_ImageNet/
--GPU
--ckpt /path_to_EMA_checkpoint/
--dataset_sink > tinynet_c_eval.log 2>&1 &
```
> checkpoint can be produced in training process.
# [Model Description](#contents)
## [Performance](#contents)
#### Evaluation Performance
| Model | FLOPs | Latency* | ImageNet Top-1 |
| ------------------- | ----- | -------- | -------------- |
| EfficientNet-B0 | 387M | 99.85 ms | 76.7% |
| TinyNet-A | 339M | 81.30 ms | 76.8% |
| EfficientNet-B^{-4} | 24M | 11.54 ms | 56.7% |
| TinyNet-E | 24M | 9.18 ms | 59.9% |
*Latency is measured using MS Lite on Huawei P40 smartphone.
*More details in [Paper](https://arxiv.org/abs/2010.14819).
# [Description of Random Situation](#contents)
We set the seed inside dataset.py. We also use random seed in train.py.
# [Model Zoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Inference Interface"""
import sys
import os
import argparse
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore import context
from src.dataset import create_dataset_val
from src.utils import count_params
from src.loss import LabelSmoothingCrossEntropy
from src.tinynet import tinynet
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/',
metavar='DIR', help='path to dataset')
parser.add_argument('--model', default='tinynet_c', type=str, metavar='MODEL',
help='Name of model to train (default: "tinynet_c"')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='label smoothing (default: 0.1)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--ckpt', type=str, default=None,
help='model checkpoint to load')
parser.add_argument('--GPU', action='store_true', default=True,
help='Use GPU for training (default: True)')
parser.add_argument('--dataset_sink', action='store_true', default=True)
def main():
"""Main entrance for training"""
args = parser.parse_args()
print(sys.argv)
context.set_context(mode=context.GRAPH_MODE)
if args.GPU:
context.set_context(device_target='GPU')
# parse model argument
assert args.model.startswith(
"tinynet"), "Only Tinynet models are supported."
_, sub_name = args.model.split("_")
net = tinynet(sub_model=sub_name,
num_classes=args.num_classes,
drop_rate=0.0,
drop_connect_rate=0.0,
global_pool="avg",
bn_tf=False,
bn_momentum=None,
bn_eps=None)
print("Total number of parameters:", count_params(net))
input_size = net.default_cfg['input_size'][1]
val_data_url = os.path.join(args.data_path, 'val')
val_dataset = create_dataset_val(args.batch_size,
val_data_url,
workers=args.workers,
distributed=False,
input_size=input_size)
loss = LabelSmoothingCrossEntropy(smooth_factor=args.smoothing,
num_classes=args.num_classes)
loss.add_flags_recursive(fp32=True, fp16=False)
eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
ckpt = load_checkpoint(args.ckpt)
load_param_into_net(net, ckpt)
net.set_train(False)
model = Model(net, loss, metrics=eval_metrics)
metrics = model.eval(val_dataset, dataset_sink_mode=False)
print(metrics)
if __name__ == '__main__':
main()

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=1
export start=0
export value=$((start + RANK_SIZE))
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo "$curtime"
rm ${current_exec_path}/device${start}_$curtime/ -rf
mkdir ${current_exec_path}/device${start}_$curtime
cd ${current_exec_path}/device${start}_$curtime || exit
export RANK_ID=start
export DEVICE_ID=start
time python3 ${current_exec_path}/eval.py \
--model tinynet_c \
--num-classes 1000 \
--batch-size 128 \
--workers 8 \
--data_path /path_to_ImageNet/\
--GPU \
--ckpt /path_to_ckpt/ \
--dataset_sink > tinynet_c_eval.log 2>&1 &

@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=1
export start=0
export value=$(($start+$RANK_SIZE))
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo $curtime
echo "rank_id = ${start}"
rm ${current_exec_path}/device_$start/ -rf
mkdir ${current_exec_path}/device_$start
cd ${current_exec_path}/device_$start || exit
export RANK_ID=$start
export DEVICE_ID=$start
time python3 ${current_exec_path}/train.py \
--model tinynet_c \
--drop 0.2 \
--drop-connect 0 \
--num-classes 1000 \
--opt-eps 0.001 \
--lr 0.048 \
--batch-size 128 \
--decay-epochs 2.4 \
--warmup-lr 1e-6 \
--warmup-epochs 3 \
--decay-rate 0.97 \
--ema-decay 0.9999 \
--weight-decay 1e-5 \
--epochs 100\
--ckpt_save_epoch 1 \
--workers 8 \
--amp_level O0 \
--opt rmsprop \
--data_path /path_to_ImageNet/ \
--GPU \
--dataset_sink > tinynet_c.log 2>&1 &
cd ${current_exec_path} || exit

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# below help function was adapted from
# https://unix.stackexchange.com/questions/31414/how-can-i-pass-a-command-line-argument-into-a-shell-script
helpFunction()
{
echo ""
echo "Usage: $0 -n num_device"
echo -e "\t-n how many gpus to use for training"
exit 1 # Exit script after printing help
}
while getopts "n:" opt
do
case "$opt" in
n ) num_device="$OPTARG" ;;
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
esac
done
# Print helpFunction in case parameters are empty
if [ -z "$num_device" ]
then
echo "Some or all of the parameters are empty";
helpFunction
fi
# Begin script in case all parameters are correct
echo "$num_device"
cd ../ || exit
current_exec_path=$(pwd)
echo ${current_exec_path}
export SLOG_PRINT_TO_STDOUT=0
export RANK_SIZE=$num_device
export curtime
curtime=$(date '+%Y%m%d-%H%M%S')
echo $curtime
echo $curtime >> starttime
rm ${current_exec_path}/device_parallel/ -rf
mkdir ${current_exec_path}/device_parallel
cd ${current_exec_path}/device_parallel || exit
echo $curtime >> starttime
time mpirun -n $RANK_SIZE --allow-run-as-root python3 ${current_exec_path}/train.py \
--model tinynet_c \
--drop 0.2 \
--drop-connect 0 \
--num-classes 1000 \
--opt-eps 0.001 \
--lr 0.048 \
--batch-size 128 \
--decay-epochs 2.4 \
--warmup-lr 1e-6 \
--warmup-epochs 3 \
--decay-rate 0.97 \
--ema-decay 0.9999 \
--weight-decay 1e-5 \
--per_print_times 100 \
--epochs 450 \
--ckpt_save_epoch 1 \
--workers 8 \
--amp_level O0 \
--opt rmsprop \
--distributed \
--data_path /path_to_ImageNet/ \
--GPU \
--dataset_sink > tinynet_c.log 2>&1 &

@ -0,0 +1,203 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""custom callbacks for ema and loss"""
from copy import deepcopy
import numpy as np
from mindspore.train.callback import Callback
from mindspore.common.parameter import Parameter
from mindspore.train.serialization import save_checkpoint
from mindspore.nn import Loss, Top1CategoricalAccuracy, Top5CategoricalAccuracy
from mindspore.train.model import Model
from mindspore import Tensor
def load_nparray_into_net(net, array_dict):
"""
Loads dictionary of numpy arrays into network.
Args:
net (Cell): Cell network.
array_dict (dict): dictionary of numpy array format model weights.
"""
param_not_load = []
for _, param in net.parameters_and_names():
if param.name in array_dict:
new_param = array_dict[param.name]
param.set_data(Parameter(new_param.copy(), name=param.name))
else:
param_not_load.append(param.name)
return param_not_load
class EmaEvalCallBack(Callback):
"""
Call back that will evaluate the model and save model checkpoint at
the end of training epoch.
Args:
model: Mindspore model instance.
ema_network: step-wise exponential moving average for ema_network.
eval_dataset: the evaluation daatset.
decay (float): ema decay.
save_epoch (int): defines how often to save checkpoint.
dataset_sink_mode (bool): whether to use data sink mode.
start_epoch (int): which epoch to start/resume training.
"""
def __init__(self, model, ema_network, eval_dataset, loss_fn, decay=0.999,
save_epoch=1, dataset_sink_mode=True, start_epoch=0):
self.model = model
self.ema_network = ema_network
self.eval_dataset = eval_dataset
self.loss_fn = loss_fn
self.decay = decay
self.save_epoch = save_epoch
self.shadow = {}
self.ema_accuracy = {}
self.best_ema_accuracy = 0
self.best_accuracy = 0
self.best_ema_epoch = 0
self.best_epoch = 0
self._start_epoch = start_epoch
self.eval_metrics = {'Validation-Loss': Loss(),
'Top1-Acc': Top1CategoricalAccuracy(),
'Top5-Acc': Top5CategoricalAccuracy()}
self.dataset_sink_mode = dataset_sink_mode
def begin(self, run_context):
"""Initialize the EMA parameters """
cb_params = run_context.original_args()
for _, param in cb_params.network.parameters_and_names():
self.shadow[param.name] = deepcopy(param.data.asnumpy())
def step_end(self, run_context):
"""Update the EMA parameters"""
cb_params = run_context.original_args()
for _, param in cb_params.network.parameters_and_names():
new_average = (1.0 - self.decay) * param.data.asnumpy().copy() + \
self.decay * self.shadow[param.name]
self.shadow[param.name] = new_average
def epoch_end(self, run_context):
"""evaluate the model and ema-model at the end of each epoch"""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num + self._start_epoch - 1
save_ckpt = (cur_epoch % self.save_epoch == 0)
acc = self.model.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
print("Model Accuracy:", acc)
load_nparray_into_net(self.ema_network, self.shadow)
self.ema_network.set_train(False)
model_ema = Model(self.ema_network, loss_fn=self.loss_fn,
metrics=self.eval_metrics)
ema_acc = model_ema.eval(
self.eval_dataset, dataset_sink_mode=self.dataset_sink_mode)
print("EMA-Model Accuracy:", ema_acc)
self.ema_accuracy[cur_epoch] = ema_acc["Top1-Acc"]
output = [{"name": k, "data": Tensor(v)}
for k, v in self.shadow.items()]
if self.best_ema_accuracy < ema_acc["Top1-Acc"]:
self.best_ema_accuracy = ema_acc["Top1-Acc"]
self.best_ema_epoch = cur_epoch
save_checkpoint(output, "ema_best.ckpt")
if self.best_accuracy < acc["Top1-Acc"]:
self.best_accuracy = acc["Top1-Acc"]
self.best_epoch = cur_epoch
print("Best Model Accuracy: %s, at epoch %s" %
(self.best_accuracy, self.best_epoch))
print("Best EMA-Model Accuracy: %s, at epoch %s" %
(self.best_ema_accuracy, self.best_ema_epoch))
if save_ckpt:
# Save the ema_model checkpoints
ckpt = "{}-{}.ckpt".format("ema", cur_epoch)
save_checkpoint(output, ckpt)
save_checkpoint(output, "ema_last.ckpt")
# Save the model checkpoints
save_checkpoint(cb_params.train_network, "last.ckpt")
print("Top 10 EMA-Model Accuracies: ")
count = 0
for epoch in sorted(self.ema_accuracy, key=self.ema_accuracy.get,
reverse=True):
if count == 10:
break
print("epoch: %s, Top-1: %s)" % (epoch, self.ema_accuracy[epoch]))
count += 1
class LossMonitor(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF, it will terminate training.
Note:
If per_print_times is 0, do not print loss.
Args:
lr_array (numpy.array): scheduled learning rate.
total_epochs (int): Total number of epochs for training.
per_print_times (int): Print the loss every time. Default: 1.
start_epoch (int): which epoch to start, used when resume from a
certain epoch.
Raises:
ValueError: If print_step is not an integer or less than zero.
"""
def __init__(self, lr_array, total_epochs, per_print_times=1, start_epoch=0):
super(LossMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self._lr_array = lr_array
self._total_epochs = total_epochs
self._start_epoch = start_epoch
def step_end(self, run_context):
"""log epoch, step, loss and learning rate"""
cb_params = run_context.original_args()
loss = cb_params.net_outputs
cur_epoch_num = cb_params.cur_epoch_num + self._start_epoch - 1
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
global_step = cb_params.cur_step_num - 1
cur_step_in_epoch = global_step % cb_params.batch_num + 1
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cur_step_in_epoch % self._per_print_times == 0:
print("epoch: %s/%s, step: %s/%s, loss is %s, learning rate: %s"
% (cur_epoch_num, self._total_epochs, cur_step_in_epoch,
cb_params.batch_num, loss, self._lr_array[global_step]),
flush=True)

@ -0,0 +1,143 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Data operations, will be used in train.py and eval.py"""
import math
import os
import numpy as np
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.dataset.transforms.py_transforms as py_transforms
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size
from mindspore.dataset.vision import Inter
# values that should remain constant
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
# data preprocess configs
SCALE = (0.08, 1.0)
RATIO = (3./4., 4./3.)
ds.config.set_seed(1)
def split_imgs_and_labels(imgs, labels, batchInfo):
"""split data into labels and images"""
ret_imgs = []
ret_labels = []
for i, image in enumerate(imgs):
ret_imgs.append(image)
ret_labels.append(labels[i])
return np.array(ret_imgs), np.array(ret_labels)
def create_dataset(batch_size, train_data_url='', workers=8, distributed=False,
input_size=224, color_jitter=0.4):
"""Creat ImageNet training dataset"""
if not os.path.exists(train_data_url):
raise ValueError('Path not exists')
decode_op = py_vision.Decode()
type_cast_op = c_transforms.TypeCast(mstype.int32)
random_resize_crop_bicubic = py_vision.RandomResizedCrop(size=(input_size, input_size),
scale=SCALE, ratio=RATIO,
interpolation=Inter.BICUBIC)
random_horizontal_flip_op = py_vision.RandomHorizontalFlip(0.5)
adjust_range = (max(0, 1 - color_jitter), 1 + color_jitter)
random_color_jitter_op = py_vision.RandomColorAdjust(brightness=adjust_range,
contrast=adjust_range,
saturation=adjust_range)
to_tensor = py_vision.ToTensor()
nromlize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# assemble all the transforms
image_ops = py_transforms.Compose([decode_op, random_resize_crop_bicubic,
random_horizontal_flip_op, random_color_jitter_op, to_tensor, nromlize_op])
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset_train = ds.ImageFolderDataset(train_data_url,
num_parallel_workers=workers,
shuffle=True,
num_shards=rank_size,
shard_id=rank_id)
dataset_train = dataset_train.map(input_columns=["image"],
operations=image_ops,
num_parallel_workers=workers)
dataset_train = dataset_train.map(input_columns=["label"],
operations=type_cast_op,
num_parallel_workers=workers)
# batch dealing
ds_train = dataset_train.batch(batch_size,
per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
ds_train = ds_train.repeat(1)
return ds_train
def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False,
input_size=224):
"""Creat ImageNet validation dataset"""
if not os.path.exists(val_data_url):
raise ValueError('Path not exists')
rank_id = get_rank() if distributed else 0
rank_size = get_group_size() if distributed else 1
dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
num_shards=rank_size, shard_id=rank_id)
scale_size = None
if isinstance(input_size, tuple):
assert len(input_size) == 2
if input_size[-1] == input_size[-2]:
scale_size = int(math.floor(input_size[0] / DEFAULT_CROP_PCT))
else:
scale_size = tuple([int(x / DEFAULT_CROP_PCT) for x in input_size])
else:
scale_size = int(math.floor(input_size / DEFAULT_CROP_PCT))
type_cast_op = c_transforms.TypeCast(mstype.int32)
decode_op = py_vision.Decode()
resize_op = py_vision.Resize(size=scale_size, interpolation=Inter.BICUBIC)
center_crop = py_vision.CenterCrop(size=input_size)
to_tensor = py_vision.ToTensor()
nromlize_op = py_vision.Normalize(
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
image_ops = py_transforms.Compose([decode_op, resize_op, center_crop,
to_tensor, nromlize_op])
dataset = dataset.map(input_columns=["label"], operations=type_cast_op,
num_parallel_workers=workers)
dataset = dataset.map(input_columns=["image"], operations=image_ops,
num_parallel_workers=workers)
dataset = dataset.batch(batch_size, per_batch_map=split_imgs_and_labels,
input_columns=["image", "label"],
num_parallel_workers=2,
drop_remainder=True)
dataset = dataset.repeat(1)
return dataset

@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor /
(num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
self.cast = P.Cast()
def construct(self, logits, label):
label = self.cast(label, mstype.int32)
one_hot_label = self.onehot(label, F.shape(
logits)[1], self.on_value, self.off_value)
loss_logit = self.ce(logits, one_hot_label)
loss_logit = self.mean(loss_logit, 0)
return loss_logit

File diff suppressed because it is too large Load Diff

@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""model utils"""
import math
import argparse
import numpy as np
def str2bool(value):
"""Convert string arguments to bool type"""
if value.lower() in ('yes', 'true', 't', 'y', '1'):
return True
if value.lower() in ('no', 'false', 'f', 'n', '0'):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_epochs=1, decay_rate=0.9,
warmup_epochs=0., warmup_lr_init=0., global_epoch=0):
"""Get scheduled learning rate"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
global_steps = steps_per_epoch * global_epoch
self_warmup_delta = ((base_lr - warmup_lr_init) / \
warmup_epochs) if warmup_epochs > 0 else 0
self_decay_rate = decay_rate if decay_rate < 1 else 1/decay_rate
for i in range(total_steps):
epochs = math.floor(i/steps_per_epoch)
cond = 1 if (epochs < warmup_epochs) else 0
warmup_lr = warmup_lr_init + epochs * self_warmup_delta
decay_nums = math.floor(epochs / decay_epochs)
decay_rate = math.pow(self_decay_rate, decay_nums)
decay_lr = base_lr * decay_rate
lr = cond * warmup_lr + (1 - cond) * decay_lr
lr_each_step.append(lr)
lr_each_step = lr_each_step[global_steps:]
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def add_weight_decay(net, weight_decay=1e-5, skip_list=None):
"""Apply weight decay to only conv and dense layers (len(shape) > =2)
Args:
net (mindspore.nn.Cell): Mindspore network instance
weight_decay (float): weight decay tobe used.
skip_list (tuple): list of parameter names without weight decay
Returns:
A list of group of parameters, separated by different weight decay.
"""
decay = []
no_decay = []
if not skip_list:
skip_list = ()
for param in net.trainable_params():
if len(param.shape) == 1 or \
param.name.endswith(".bias") or \
param.name in skip_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def count_params(net):
"""Count number of parameters in the network
Args:
net (mindspore.nn.Cell): Mindspore network instance
Returns:
total_params (int): Total number of trainable params
"""
total_params = 0
for param in net.trainable_params():
total_params += np.prod(param.shape)
return total_params

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save