!4478 Add an example of training NASNet in MindSpore
Merge pull request !4478 from dessyang/masterpull/4478/MERGE
commit
a6c1fb2c25
@ -0,0 +1,111 @@
|
|||||||
|
# NASNet Example
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
This is an example of training NASNet-A-Mobile in MindSpore.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- Install [MindSpore](http://www.mindspore.cn/install/en).
|
||||||
|
- Download the dataset.
|
||||||
|
|
||||||
|
## Structure
|
||||||
|
|
||||||
|
```shell
|
||||||
|
.
|
||||||
|
└─nasnet
|
||||||
|
├─README.md
|
||||||
|
├─scripts
|
||||||
|
├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p)
|
||||||
|
├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p)
|
||||||
|
└─run_eval_for_gpu.sh # launch evaluating with gpu platform
|
||||||
|
├─src
|
||||||
|
├─config.py # parameter configuration
|
||||||
|
├─dataset.py # data preprocessing
|
||||||
|
├─loss.py # Customized CrossEntropy loss function
|
||||||
|
├─lr_generator.py # learning rate generator
|
||||||
|
├─nasnet_a_mobile.py # network definition
|
||||||
|
├─eval.py # eval net
|
||||||
|
├─export.py # convert checkpoint
|
||||||
|
└─train.py # train net
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parameter Configuration
|
||||||
|
|
||||||
|
Parameters for both training and evaluation can be set in `src/config.py`.
|
||||||
|
|
||||||
|
```
|
||||||
|
'random_seed': 1, # fix random seed
|
||||||
|
'rank': 0, # local rank of distributed
|
||||||
|
'group_size': 1, # world size of distributed
|
||||||
|
'work_nums': 8, # number of workers to read the data
|
||||||
|
'epoch_size': 250, # total epoch numbers
|
||||||
|
'keep_checkpoint_max': 100, # max numbers to keep checkpoints
|
||||||
|
'ckpt_path': './checkpoint/', # save checkpoint path
|
||||||
|
'is_save_on_master': 1, # save checkpoint on rank0, distributed parameters
|
||||||
|
'batch_size': 32, # input batchsize
|
||||||
|
'num_classes': 1000, # dataset class numbers
|
||||||
|
'label_smooth_factor': 0.1, # label smoothing factor
|
||||||
|
'aux_factor': 0.4, # loss factor of aux logit
|
||||||
|
'lr_init': 0.04, # initiate learning rate
|
||||||
|
'lr_decay_rate': 0.97, # decay rate of learning rate
|
||||||
|
'num_epoch_per_decay': 2.4, # decay epoch number
|
||||||
|
'weight_decay': 0.00004, # weight decay
|
||||||
|
'momentum': 0.9, # momentum
|
||||||
|
'opt_eps': 1.0, # epsilon
|
||||||
|
'rmsprop_decay': 0.9, # rmsprop decay
|
||||||
|
'loss_scale': 1, # loss scale
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
### Train
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
# distribute training example(8p)
|
||||||
|
sh run_distribute_train_for_gpu.sh DATA_DIR
|
||||||
|
# standalone training
|
||||||
|
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# distributed training example(8p) for GPU
|
||||||
|
sh scripts/run_distribute_train_for_gpu.sh /dataset/train
|
||||||
|
# standalone training example for GPU
|
||||||
|
sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Result
|
||||||
|
|
||||||
|
Checkpoint files are saved in the directory given by `ckpt_path` in `src/config.py`, and the training output is written to `train.log`.
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
# Evaluation
|
||||||
|
sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Launch
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Evaluation with checkpoint
|
||||||
|
sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/nasnet-a-mobile-rank0-248_10009.ckpt
|
||||||
|
```
|
||||||
|
|
||||||
|
> checkpoint can be produced in training process.
|
||||||
|
|
||||||
|
#### Result
|
||||||
|
|
||||||
|
The evaluation result is written to `eval.log` in the directory the script was launched from. The log contains the metrics (loss, top-1 accuracy and top-5 accuracy) printed by `eval.py`.
|
||||||
|
|
@ -0,0 +1,53 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""evaluate imagenet"""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
|
||||||
|
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.nasnet_a_mobile import NASNetAMobile
|
||||||
|
from src.loss import CrossEntropy_Val
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Evaluate a trained NASNet-A-Mobile checkpoint on an image-folder dataset
    # and print loss / top-1 / top-5 metrics.
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    if args_opt.platform == 'Ascend':
        # Only pin the device when DEVICE_ID is actually set to a number;
        # int(None) would otherwise abort the script with a TypeError.
        device_id = os.getenv('DEVICE_ID', '')
        if device_id.isdigit():
            context.set_context(device_id=int(device_id))

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)

    # Build the inference network and restore the trained weights.
    net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    # do_train=False selects the deterministic resize + center-crop pipeline.
    dataset = create_dataset(args_opt.dataset_path, cfg, False)

    # Take the smoothing factor from the shared config instead of a literal so
    # the evaluation loss follows the configured value.
    loss = CrossEntropy_Val(smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset)
    print("metric: ", metrics)
|
@ -0,0 +1,39 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
##############export checkpoint file into geir and onnx models#################
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||||
|
|
||||||
|
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||||
|
from src.nasnet_a_mobile import NASNetAMobile
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Load a NASNet-A-Mobile checkpoint and export the network to the ONNX
    # and GEIR formats, using the file names from the shared config.
    parser = argparse.ArgumentParser(description='checkpoint export')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of nasnet_a_mobile (Default: None)')
    args_opt = parser.parse_args()

    net = NASNetAMobile(num_classes=cfg.num_classes, is_training=False)
    param_dict = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, param_dict)

    # Dummy single-image input; it is passed to export() as the example input.
    input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, cfg.image_size, cfg.image_size]), ms.float32)
    export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
    # The config defines 'air_filename' (there is no 'geir_filename' key), so
    # use that key for the second export target.
    export(net, input_arr, file_name=cfg.air_filename, file_format="GEIR")
|
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# Launch 8-process distributed GPU training in the background.
# Usage: sh run_distribute_train_for_gpu.sh DATA_DIR
# Output (stdout and stderr) is written to train.log in the current directory.
if [ $# -ne 1 ]; then
    echo "Usage: sh run_distribute_train_for_gpu.sh DATA_DIR"
    exit 1
fi

DATA_DIR=$1
# Quote the path so dataset directories containing spaces are passed intact.
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
|
@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# Launch evaluation on a single GPU in the background.
# Usage: sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
# Output (stdout and stderr) is written to eval.log in the current directory.
if [ $# -ne 3 ]; then
    echo "Usage: sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT"
    exit 1
fi

DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
# Quote the paths so directories or file names containing spaces are passed intact.
CUDA_VISIBLE_DEVICES="$DEVICE_ID" python ./eval.py --platform 'GPU' --dataset_path "$DATA_DIR" --checkpoint "$PATH_CHECKPOINT" > eval.log 2>&1 &
|
@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
# Launch single-GPU training in the background.
# Usage: sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
# Output (stdout and stderr) is written to train.log in the current directory.
if [ $# -ne 2 ]; then
    echo "Usage: sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR"
    exit 1
fi

DEVICE_ID=$1
DATA_DIR=$2
# Quote the path so dataset directories containing spaces are passed intact.
CUDA_VISIBLE_DEVICES="$DEVICE_ID" python ./train.py --platform 'GPU' --dataset_path "$DATA_DIR" > train.log 2>&1 &
|
||||||
|
|
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
network config setting, will be used in train.py, eval.py and export.py
|
||||||
|
"""
|
||||||
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
|
|
||||||
|
# Shared hyper-parameter set for NASNet-A-Mobile on GPU. It is imported as
# `cfg` by the training, evaluation and export scripts.
nasnet_a_mobile_config_gpu = edict({
    'random_seed': 1,  # seed applied to random, numpy and the dataset engine
    'rank': 0,  # rank of this process; overwritten by train.py at start-up
    'group_size': 1,  # number of processes; overwritten by train.py at start-up
    'work_nums': 8,  # number of parallel workers used to read/transform data
    'epoch_size': 312,  # total number of training epochs
    'keep_checkpoint_max': 100,  # maximum number of checkpoint files to keep
    'ckpt_path': './nasnet_a_mobile_checkpoint/',  # directory checkpoints are saved to
    'is_save_on_master': 0,  # when non-zero, only rank 0 saves checkpoints in distributed runs

    ### Dataset Config
    'batch_size': 32,  # batch size used by the dataset pipeline
    'image_size': 224,  # side length of the square network input, in pixels
    'num_classes': 1000,  # number of target classes

    ### Loss Config
    'label_smooth_factor': 0.1,  # label smoothing factor
    'aux_factor': 0.4,  # loss factor of the auxiliary logits

    ### Learning Rate Config
    # 'lr_decay_method': 'exponential',
    'lr_init': 0.04,  # initial learning rate
    'lr_decay_rate': 0.97,  # multiplicative decay rate of the learning rate
    'num_epoch_per_decay': 2.4,  # number of epochs between two decays

    ### Optimization Config
    'weight_decay': 0.00004,  # weight decay passed to RMSProp
    'momentum': 0.9,  # momentum passed to RMSProp
    'opt_eps': 1.0,  # epsilon passed to RMSProp
    'rmsprop_decay': 0.9,  # decay passed to RMSProp
    "loss_scale": 1,  # loss scale passed to RMSProp

    ### onnx&air Config
    'onnx_filename': 'nasnet_a_mobile.onnx',  # output file name for the ONNX export
    'air_filename': 'nasnet_a_mobile.air'  # presumably the output file name for the AIR/GEIR export
})
|
@ -0,0 +1,70 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Data operations, will be used in train.py and eval.py
|
||||||
|
"""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.c_transforms as C2
|
||||||
|
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset(dataset_path, config, do_train, repeat_num=1):
    """
    Build the image-folder dataset pipeline for training or evaluation.

    Args:
        dataset_path(string): root directory of the image-folder dataset.
        config(dict): configuration object; `rank`, `group_size`, `work_nums`,
            `image_size` and `batch_size` are read from it.
        do_train(bool): True selects the random-augmentation pipeline, False the
            resize + center-crop pipeline.
        repeat_num(int): the repeat times of dataset. Default: 1.

    Returns:
        dataset, batched (incomplete final batch dropped) and repeated.
    """
    shard_id = config.rank
    shard_count = config.group_size

    # Sharding arguments are only passed when more than one process takes part.
    reader_kwargs = {'num_parallel_workers': config.work_nums, 'shuffle': True}
    if shard_count != 1:
        reader_kwargs['num_shards'] = shard_count
        reader_kwargs['shard_id'] = shard_id
    data_set = de.ImageFolderDatasetV2(dataset_path, **reader_kwargs)

    # Image transforms: augmentation for training, deterministic crop otherwise.
    if do_train:
        image_ops = [
            C.RandomCropDecodeResize(config.image_size),
            C.RandomHorizontalFlip(prob=0.5),
            # fast mode; the full variant would be
            # C.RandomColorAdjust(brightness=0.4, contrast=0.5, saturation=0.5, hue=0.2)
            C.RandomColorAdjust(brightness=0.4, saturation=0.5)
        ]
    else:
        image_ops = [
            C.Decode(),
            # Resize to image_size / 0.875 first so the center crop keeps 87.5% of the side.
            C.Resize(int(config.image_size/0.875)),
            C.CenterCrop(config.image_size)
        ]
    # Shared tail: scale by 1/255, normalize with mean/std 0.5, convert HWC -> CHW.
    image_ops.extend([
        C.Rescale(1.0 / 255.0, 0.0),
        C.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        C.HWC2CHW()
    ])
    label_cast = C2.TypeCast(mstype.int32)

    data_set = data_set.map(input_columns="image", operations=image_ops, num_parallel_workers=config.work_nums)
    data_set = data_set.map(input_columns="label", operations=label_cast, num_parallel_workers=config.work_nums)

    # Batch first, then repeat.
    data_set = data_set.batch(config.batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_num)
    return data_set
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""define evaluation loss function for network."""
|
||||||
|
from mindspore.nn.loss.loss import _Loss
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
import mindspore.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class CrossEntropy_Val(_Loss):
    """
    Label-smoothed softmax cross entropy used during inference/evaluation.

    Args:
        smooth_factor(float): label smoothing factor. Default: 0.
        num_classes(int): number of classes, used to spread the smoothing mass
            over the non-target classes. Default: 1000.
    """

    def __init__(self, smooth_factor=0, num_classes=1000):
        super().__init__()
        # The target class gets 1 - smooth_factor; every other class shares
        # smooth_factor equally.
        target_value = 1.0 - smooth_factor
        other_value = 1.0 * smooth_factor / (num_classes - 1)
        self.onehot = P.OneHot()
        self.on_value = Tensor(target_value, mstype.float32)
        self.off_value = Tensor(other_value, mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        """Return the loss of `logits` against integer `label`, mean-reduced over axis 0."""
        depth = F.shape(logits)[1]
        smoothed_label = self.onehot(label, depth, self.on_value, self.off_value)
        per_sample_loss = self.ce(logits, smoothed_label)
        return self.mean(per_sample_loss, 0)
|
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""learning rate exponential decay generator"""
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False):
    """
    Generate an exponentially decaying learning rate, one value per training step.

    The value at step `i` is `lr_init * lr_decay_rate ** p`, where
    `p = i / (steps_per_epoch * num_epoch_per_decay)`; with `is_stair` set,
    `p` is rounded down so the rate changes in discrete jumps.

    Args:
        lr_init(float): init learning rate
        lr_decay_rate(float): multiplicative decay factor applied per decay period
        num_epoch_per_decay(float): length of one decay period, in epochs
        total_epochs(int): total epoch of training
        steps_per_epoch(int): steps of one epoch
        is_stair(bool): If `True` decay the learning rate at discrete intervals

    Returns:
        np.array, float32 learning rate array of length `steps_per_epoch * total_epochs`
    """
    step_count = steps_per_epoch * total_epochs
    steps_per_decay = steps_per_epoch * num_epoch_per_decay

    def _exponent(step):
        # Fractional number of decay periods elapsed, floored in stair mode.
        elapsed = step/steps_per_decay
        return math.floor(elapsed) if is_stair else elapsed

    schedule = [lr_init * math.pow(lr_decay_rate, _exponent(step)) for step in range(step_count)]
    return np.array(schedule).astype(np.float32)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,117 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""train imagenet."""
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore import context
|
||||||
|
from mindspore import ParallelMode
|
||||||
|
from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
|
from mindspore.nn.optim.rmsprop import RMSProp
|
||||||
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||||
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
|
from mindspore import dataset as de
|
||||||
|
|
||||||
|
from src.config import nasnet_a_mobile_config_gpu as cfg
|
||||||
|
from src.dataset import create_dataset
|
||||||
|
from src.nasnet_a_mobile import NASNetAMobileWithLoss, NASNetAMobileTrainOneStepWithClipGradient
|
||||||
|
from src.lr_generator import get_lr
|
||||||
|
|
||||||
|
|
||||||
|
# Seed every random source used by the pipeline so runs are repeatable.
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
de.config.set_seed(cfg.random_seed)


if __name__ == '__main__':
    # Train NASNet-A-Mobile, standalone or data-parallel, with RMSProp and an
    # exponentially decaying learning rate; checkpoints are saved once per epoch.
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
    parser.add_argument('--is_distributed', action='store_true', default=False,
                        help='distributed training')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
    # Pin the device only when DEVICE_ID is set to a number.
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if args_opt.is_distributed:
        if args_opt.platform == "Ascend":
            init()
        else:
            init("nccl")
        cfg.rank = get_rank()
        cfg.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
                                          parameter_broadcast=True, mirror_mean=True)
    else:
        cfg.rank = 0
        cfg.group_size = 1

    # dataloader
    dataset = create_dataset(args_opt.dataset_path, cfg, True)
    batches_per_epoch = dataset.get_dataset_size()

    # network
    net_with_loss = NASNetAMobileWithLoss(cfg)
    if args_opt.resume:
        ckpt = load_checkpoint(args_opt.resume)
        load_param_into_net(net_with_loss, ckpt)

    # learning rate schedule: one value per training step
    lr = get_lr(lr_init=cfg.lr_init, lr_decay_rate=cfg.lr_decay_rate,
                num_epoch_per_decay=cfg.num_epoch_per_decay, total_epochs=cfg.epoch_size,
                steps_per_epoch=batches_per_epoch, is_stair=True)
    lr = Tensor(lr)

    # optimizer: parameters whose name contains beta/gamma/bias are excluded from weight decay
    decayed_params = []
    no_decayed_params = []
    for param in net_with_loss.trainable_params():
        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
            decayed_params.append(param)
        else:
            no_decayed_params.append(param)
    # A group without a 'weight_decay' key falls back to the optimizer-level
    # weight_decay, so the no-decay group must set it to 0.0 explicitly.
    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
                    {'params': no_decayed_params, 'weight_decay': 0.0},
                    {'order_params': net_with_loss.trainable_params()}]
    optimizer = RMSProp(group_params, lr, decay=cfg.rmsprop_decay, weight_decay=cfg.weight_decay,
                        momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale)

    net_with_grads = NASNetAMobileTrainOneStepWithClipGradient(net_with_loss, optimizer)
    net_with_grads.set_train()
    model = Model(net_with_grads)

    print("============== Starting Training ==============")
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix=f"nasnet-a-mobile-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)

    # In distributed runs with is_save_on_master enabled only rank 0 writes
    # checkpoints; in every other case each process saves its own.
    save_on_master_only = args_opt.is_distributed and bool(cfg.is_save_on_master)
    if not save_on_master_only or cfg.rank == 0:
        callbacks.append(ckpoint_cb)
    model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
    print("train success")
|
Loading…
Reference in new issue