parent
9c79b9d712
commit
bd4e441862
@@ -0,0 +1,111 @@ README.md
# EfficientNet-B0 Example

## Description

This is an example of training EfficientNet-B0 in MindSpore.

## Requirements

- Install [MindSpore](http://www.mindspore.cn/install/en).
- Download the ImageNet dataset.

## Structure

```shell
.
└─efficientnet
  ├─README.md
  ├─scripts
    ├─run_standalone_train_for_gpu.sh # launch standalone training with gpu platform(1p)
    ├─run_distribute_train_for_gpu.sh # launch distributed training with gpu platform(8p)
    └─run_eval_for_gpu.sh             # launch evaluating with gpu platform
  ├─src
    ├─config.py                       # parameter configuration
    ├─dataset.py                      # data preprocessing
    ├─efficientnet.py                 # network definition
    ├─loss.py                         # customized loss function
    ├─transform_utils.py              # random augment utils
    └─transform.py                    # random augment class
  ├─eval.py                           # eval net
  └─train.py                          # train net
```

## Parameter Configuration

Parameters for both training and evaluation can be set in `config.py`.

```
'random_seed': 1,            # fix random seed
'model': 'efficientnet_b0',  # model name
'drop': 0.2,                 # dropout rate
'drop_connect': 0.2,         # drop connect rate
'opt_eps': 0.001,            # optimizer epsilon
'lr': 0.064,                 # learning rate LR
'batch_size': 128,           # batch size
'decay_epochs': 2.4,         # epoch interval to decay LR
'warmup_epochs': 5,          # epochs to warmup LR
'decay_rate': 0.97,          # LR decay rate
'weight_decay': 1e-5,        # weight decay
'epochs': 600,               # number of epochs to train
'workers': 8,                # number of data-processing workers
'amp_level': 'O0',           # amp level
'opt': 'rmsprop',            # optimizer
'num_classes': 1000,         # number of classes
'gp': 'avg',                 # type of global pool: "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9,             # optimizer momentum
'warmup_lr_init': 0.0001,    # initial warmup LR
'smoothing': 0.1,            # label smoothing factor
'bn_tf': False,              # use TensorFlow BatchNorm defaults
'keep_checkpoint_max': 10,   # max number of checkpoints to keep
'loss_scale': 1024,          # loss scale
'resume_start_epoch': 0,     # resume start epoch
```

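The settings are exposed as an `EasyDict` (see `src/config.py` below), so scripts read them as plain attributes. A minimal sketch of inspecting or overriding a value, assuming you run from the example root:

```python
from src.config import efficientnet_b0_config_gpu as cfg

print(cfg.lr)        # 0.064
cfg.batch_size = 64  # e.g. shrink the batch for a smaller GPU (hypothetical override)
```
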
## Running the example

### Train

#### Usage

```bash
# distributed training example (8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training example
sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
```

#### Launch

```bash
# distributed training example (8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh /dataset 0
```

#### Result

You can find the checkpoint files together with the training results in the log.

### Evaluation

#### Usage

```bash
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```

#### Launch

```bash
# Evaluation with checkpoint
sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt
```

> Checkpoints are produced during the training process.

#### Result

The evaluation result is stored in the scripts path, where you can find results like the following in the log.
@@ -0,0 +1,61 @@ eval.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate imagenet"""
import argparse
import os

import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset_val
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (default: none)')
    parser.add_argument('--data_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    # load the trained weights and switch the network to inference mode
    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)
    val_data_url = os.path.join(args_opt.data_path, 'val')
    dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)

    metrics = model.eval(dataset)
    print("metric: ", metrics)
@@ -0,0 +1,32 @@ scripts/run_distribute_train_for_gpu.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1

current_exec_path=$(pwd)
echo ${current_exec_path}

curtime=$(date '+%Y%m%d-%H%M%S')
RANK_SIZE=8

rm -rf ${current_exec_path}/device_parallel/
mkdir ${current_exec_path}/device_parallel
echo ${curtime} > ${current_exec_path}/device_parallel/starttime

mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \
    --GPU \
    --distributed \
    --data_path ${DATA_DIR} \
    --cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 &
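Each of the eight `mpirun` ranks runs `train.py` with `--distributed`; `train.py` then calls `init("nccl")` and shards the dataset by `get_rank()`/`get_group_size()`, so no per-rank environment setup is needed in this script.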
@@ -0,0 +1,27 @@ scripts/run_eval_for_gpu.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2
PATH_CHECKPOINT=$3

current_exec_path=$(pwd)
echo ${current_exec_path}

curtime=$(date '+%Y%m%d-%H%M%S')

echo ${curtime} > ${current_exec_path}/eval_starttime

CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 &
@@ -0,0 +1,31 @@ scripts/run_standalone_train_for_gpu.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2

current_exec_path=$(pwd)
echo ${current_exec_path}

curtime=$(date '+%Y%m%d-%H%M%S')

rm -rf ${current_exec_path}/device_${DEVICE_ID}/
mkdir ${current_exec_path}/device_${DEVICE_ID}
echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime

CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \
    --GPU \
    --data_path ${DATA_DIR} \
    --cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 &
@@ -0,0 +1,47 @@ src/config.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict

efficientnet_b0_config_gpu = edict({
    'random_seed': 1,
    'model': 'efficientnet_b0',
    'drop': 0.2,
    'drop_connect': 0.2,
    'opt_eps': 0.001,
    'lr': 0.064,
    'batch_size': 128,
    'decay_epochs': 2.4,
    'warmup_epochs': 5,
    'decay_rate': 0.97,
    'weight_decay': 1e-5,
    'epochs': 600,
    'workers': 8,
    'amp_level': 'O0',
    'opt': 'rmsprop',
    'num_classes': 1000,
    # type of global pool: "avg", "max", "avgmax", "avgmaxc"
    'gp': 'avg',
    'momentum': 0.9,
    'warmup_lr_init': 0.0001,
    'smoothing': 0.1,
    # use TensorFlow BatchNorm defaults for models that support it
    'bn_tf': False,
    'keep_checkpoint_max': 10,
    'loss_scale': 1024,
    'resume_start_epoch': 0,
})
@@ -0,0 +1,125 @@ src/dataset.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import math
import os

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore.communication.management import get_group_size, get_rank
from mindspore.dataset.vision import Inter

from src.config import efficientnet_b0_config_gpu as cfg
from src.transform import RandAugment

ds.config.set_seed(cfg.random_seed)


IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

img_size = (224, 224)
crop_pct = 0.875
rescale = 1.0 / 255.0
shift = 0.0
inter_method = 'bilinear'
resize_value = 224  # img_size
scale = (0.08, 1.0)
ratio = (3. / 4., 4. / 3.)
inter_str = 'bicubic'


def str2MsInter(method):
    if method == 'bicubic':
        return Inter.BICUBIC
    if method == 'nearest':
        return Inter.NEAREST
    return Inter.BILINEAR


def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
    if not os.path.exists(train_data_url):
        raise ValueError('Path does not exist')
    interpolation = str2MsInter(inter_str)

    c_decode_op = C.Decode()
    type_cast_op = C2.TypeCast(mstype.int32)
    random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio,
                                                interpolation=interpolation)
    random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)

    efficient_rand_augment = RandAugment()

    image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]

    rank_id = get_rank() if distributed else 0
    rank_size = get_group_size() if distributed else 1

    dataset_train = ds.ImageFolderDataset(train_data_url,
                                          num_parallel_workers=workers,
                                          shuffle=True,
                                          num_shards=rank_size,
                                          shard_id=rank_id)
    dataset_train = dataset_train.map(input_columns=["image"],
                                      operations=image_ops,
                                      num_parallel_workers=workers)
    dataset_train = dataset_train.map(input_columns=["label"],
                                      operations=type_cast_op,
                                      num_parallel_workers=workers)
    # RandAugment runs per batch via per_batch_map, after the C ops above
    ds_train = dataset_train.batch(batch_size,
                                   per_batch_map=efficient_rand_augment,
                                   input_columns=["image", "label"],
                                   num_parallel_workers=2,
                                   drop_remainder=True)
    ds_train = ds_train.repeat(1)
    return ds_train


def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False):
    if not os.path.exists(val_data_url):
        raise ValueError('Path does not exist')
    rank_id = get_rank() if distributed else 0
    rank_size = get_group_size() if distributed else 1
    dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
                                    num_shards=rank_size, shard_id=rank_id, shuffle=False)
    scale_size = None
    interpolation = str2MsInter(inter_method)

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # resize the shorter side so the center crop keeps crop_pct of the image
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    type_cast_op = C2.TypeCast(mstype.int32)
    decode_op = C.Decode()
    resize_op = C.Resize(size=scale_size, interpolation=interpolation)
    center_crop = C.CenterCrop(size=224)
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    changeswap_op = C.HWC2CHW()

    ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]

    dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers)
    dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers)
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers)
    dataset = dataset.repeat(1)
    return dataset
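A small usage sketch of the two pipeline builders (paths are hypothetical; `create_dataset` applies RandAugment per batch through `per_batch_map` after the decode/crop/flip ops):

```python
from src.dataset import create_dataset, create_dataset_val

train_ds = create_dataset(128, '/dataset/train', workers=8)  # sharded when distributed=True
val_ds = create_dataset_val(128, '/dataset/val', workers=8)
print(train_ds.get_dataset_size())                           # batches per epoch
```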
File diff suppressed because it is too large (src/efficientnet.py)
@@ -0,0 +1,37 @@ src/loss.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore import Tensor
import mindspore.nn as nn


class LabelSmoothingCrossEntropy(_Loss):
    """Cross-entropy with label smoothing."""

    def __init__(self, smooth_factor=0.1, num_classes=1000):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logits, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        return loss_logit
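With `smooth_factor=0.1` and `num_classes=1000`, the smoothed target puts 0.9 on the true class and 0.1/999 on each other class, so the target distribution still sums to one. A quick check of the arithmetic:

```python
smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                   # 0.9, weight of the true class
off_value = smooth_factor / (num_classes - 1)    # ~0.0001001 on every other class
print(on_value + (num_classes - 1) * off_value)  # 1.0
```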
@@ -0,0 +1,48 @@ src/transform.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
random augment class
"""
import numpy as np
import mindspore.dataset.vision.py_transforms as P
from src import transform_utils

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class RandAugment:
    # config_str: policy string, e.g. "rand-m9-mstd0.5"
    # hparams: dict of extra augmentation hyper-parameters
    def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
        hparams = hparams if hparams is not None else {}
        self.config_str = config_str
        self.hparams = hparams

    def __call__(self, imgs, labels, batchInfo):
        # the image objects are expected to be PIL-convertible arrays
        ret_imgs = []
        ret_labels = []
        py_to_pil_op = P.ToPIL()
        to_tensor = P.ToTensor()
        normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
        for i, image in enumerate(imgs):
            img_pil = py_to_pil_op(image)
            img_pil = rand_augment_ops(img_pil)
            img_array = to_tensor(img_pil)
            img_array = normalize_op(img_array)
            ret_imgs.append(img_array)
            ret_labels.append(labels[i])
        return np.array(ret_imgs), np.array(ret_labels)
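The policy string follows the timm-style RandAugment convention; "rand-m9-mstd0.5" is read here as magnitude 9 with a magnitude std of 0.5 (an assumption based on `transform_utils.rand_augment_transform`). A minimal sketch of calling the augmenter directly, with a hypothetical one-image batch:

```python
import numpy as np
from src.transform import RandAugment

aug = RandAugment()  # default policy "rand-m9-mstd0.5"
imgs = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)]  # HWC uint8
out_imgs, out_labels = aug(imgs, [0], batchInfo=None)  # batchInfo is unused here
print(out_imgs.shape)  # (1, 3, 224, 224): ToTensor/Normalize yield CHW floats
```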
File diff suppressed because it is too large (src/transform_utils.py)
@@ -0,0 +1,191 @@ train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time

import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy

mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)


def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    """Build the full per-step LR schedule: linear warmup, then stepwise exponential decay.

    Note that warmup_steps and decay_steps are both counted in epochs.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_steps) if warmup_steps > 0 else 0
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        steps = math.floor(i / steps_per_epoch)
        cond = 1 if (steps < warmup_steps) else 0
        warmup_lr = warmup_lr_init + steps * self_warmup_delta
        decay_nums = math.floor(steps / decay_steps)
        decay_rate = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_rate
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    # when resuming, drop the steps that were already consumed
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
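
# A quick sanity check of get_lr (illustrative only, not used by the script):
# with 2 steps/epoch, 4 epochs, decay every 2 epochs and a 1-epoch warmup,
#   get_lr(base_lr=0.1, total_epochs=4, steps_per_epoch=2, decay_steps=2,
#          decay_rate=0.5, warmup_steps=1, warmup_lr_init=0.01)
# returns [0.01, 0.01, 0.1, 0.1, 0.05, 0.05, 0.05, 0.05].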

def get_outdir(path, *paths, inc=False):
    """Create the output directory; if inc=True, append a counter until the name is unused."""
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
        assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir


parser = argparse.ArgumentParser(
    description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
                    help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
                    default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')


def main():
    args, _ = parser.parse_known_args()
    devid, rank_id, rank_size = 0, 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            init()
            devid = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    else:
        if args.GPU:
            context.set_context(device_target='GPU')

    is_master = not args.distributed or (rank_id == 0)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    cur_time = args.cur_time
    output_base = './output'

    exp_name = '-'.join([
        cur_time,
        cfg.model,
        str(224)
    ])
    # stagger directory creation across ranks to avoid collisions
    time.sleep(rank_id)
    output_dir = get_outdir(output_base, exp_name)

    train_data_url = os.path.join(args.data_path, 'train')
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(
        prefix=cfg.model, directory=output_dir, config=config_ck)

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))
    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )

    # keep the loss computation in fp32 even under mixed precision
    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    # only the master rank logs and saves checkpoints
    callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []

    if args.resume:
        real_epoch = cfg.epochs - cfg.resume_start_epoch
        model.train(real_epoch, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)
    else:
        model.train(cfg.epochs, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()