!7320 add GPU efficientnet-b0 to modelzoo

Merge pull request !7320 from 34bunny/GPU-efficientnet
pull/7320/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 359543d663

@@ -0,0 +1,111 @@
# EfficientNet-B0 Example
## Description
This is an example of training EfficientNet-B0 in MindSpore.
## Requirements
- Install [MindSpore](http://www.mindspore.cn/install/en).
- Download the ImageNet dataset.
## Structure
```shell
.
└─efficientnet
  ├─README.md
  ├─scripts
    ├─run_standalone_train_for_gpu.sh   # launch standalone training with GPU platform (1p)
    ├─run_distribute_train_for_gpu.sh   # launch distributed training with GPU platform (8p)
    └─run_eval_for_gpu.sh               # launch evaluation with GPU platform
  ├─src
    ├─config.py                         # parameter configuration
    ├─dataset.py                        # data preprocessing
    ├─efficientnet.py                   # network definition
    ├─loss.py                           # customized loss function
    ├─transform_utils.py                # random augment utils
    └─transform.py                      # random augment class
  ├─eval.py                             # evaluate the network
  └─train.py                            # train the network
```
## Parameter Configuration
Parameters for both training and evaluation can be set in `src/config.py`.
```
'random_seed': 1, # fix random seed
'model': 'efficientnet_b0', # model name
'drop': 0.2, # dropout rate
'drop_connect': 0.2, # drop connect rate
'opt_eps': 0.001, # optimizer epsilon
'lr': 0.064,                     # learning rate
'batch_size': 128, # batch size
'decay_epochs': 2.4, # epoch interval to decay LR
'warmup_epochs': 5, # epochs to warmup LR
'decay_rate': 0.97, # LR decay rate
'weight_decay': 1e-5, # weight decay
'epochs': 600, # number of epochs to train
'workers': 8, # number of data processing processes
'amp_level': 'O0', # amp level
'opt': 'rmsprop', # optimizer
'num_classes': 1000, # number of classes
'gp': 'avg', # type of global pool, "avg", "max", "avgmax", "avgmaxc"
'momentum': 0.9, # optimizer momentum
'warmup_lr_init': 0.0001, # init warmup LR
'smoothing': 0.1, # label smoothing factor
'bn_tf': False, # use Tensorflow BatchNorm defaults
'keep_checkpoint_max': 10, # max number ckpts to keep
'loss_scale': 1024, # loss scale
'resume_start_epoch': 0, # resume start epoch
```
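These values are consumed by `train.py` and `eval.py` via `from src.config import efficientnet_b0_config_gpu as cfg`. To illustrate how the learning-rate entries interact, the sketch below mirrors the `get_lr` helper in `train.py` (linear warmup followed by stepwise exponential decay); it is only an illustration, and `steps_per_epoch` is an assumed example value, since the real value comes from the dataset size.
```python
import math

# Hedged illustration (not part of the repository) of how the config entries above
# drive the per-step learning-rate schedule built by get_lr() in train.py.
lr, warmup_lr_init, warmup_epochs = 0.064, 0.0001, 5
decay_epochs, decay_rate, epochs = 2.4, 0.97, 600
steps_per_epoch = 100  # assumed example; in train.py it is the dataset size in batches

schedule = []
warmup_delta = (lr - warmup_lr_init) / warmup_epochs
for step in range(steps_per_epoch * epochs):
    epoch = step // steps_per_epoch
    if epoch < warmup_epochs:
        # linear warmup from warmup_lr_init towards lr
        schedule.append(warmup_lr_init + epoch * warmup_delta)
    else:
        # multiply by decay_rate once every decay_epochs epochs
        schedule.append(lr * decay_rate ** math.floor(epoch / decay_epochs))

print(schedule[0], schedule[5 * steps_per_epoch], schedule[-1])
```
With these defaults the learning rate ramps from `warmup_lr_init` towards `lr` over the first 5 epochs and is then multiplied by 0.97 every 2.4 epochs.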
## Running the example
### Train
#### Usage
```
# distributed training example (8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training example (1p)
sh run_standalone_train_for_gpu.sh DATA_DIR DEVICE_ID
```
#### Launch
```bash
# distributed training example (8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh /dataset 0
```
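Both wrapper scripts ultimately call `train.py`. Roughly, they run the commands below (log redirection and the optional `--cur_time` argument omitted), which can be handy if you want to pass extra flags such as `--resume` directly.
```bash
# distributed training (8p): what run_distribute_train_for_gpu.sh launches via mpirun
mpirun --allow-run-as-root -n 8 python train.py --GPU --distributed --data_path /dataset

# standalone training (1p): what run_standalone_train_for_gpu.sh launches, pinned to one GPU
CUDA_VISIBLE_DEVICES=0 python train.py --GPU --data_path /dataset
```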
#### Result
Checkpoints are saved under `./output/`, and the training log can be found in `device_parallel/efficientnet_b0.log` (distributed) or `device_<DEVICE_ID>/efficientnet_b0.log` (standalone).
### Evaluation
#### Usage
```
# Evaluation
sh run_eval_for_gpu.sh DATA_DIR DEVICE_ID PATH_CHECKPOINT
```
#### Launch
```bash
# Evaluation with checkpoint
sh scripts/run_eval_for_gpu.sh /dataset 0 ./checkpoint/efficientnet_b0-600_1251.ckpt
```
> The checkpoint file is produced during training; see the Train section above.
#### Result
The evaluation result is stored in the directory from which the script is launched, in `eval.log`. There you will find a line of the form `metric: {'Loss': ..., 'Top1-Acc': ..., 'Top5-Acc': ...}`.

@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate imagenet"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset_val
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification evaluation')
    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of efficientnet (Default: None)')
    parser.add_argument('--data_path', type=str, default='', help='Dataset path')
    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
    args_opt = parser.parse_args()

    if args_opt.platform == 'Ascend':
        device_id = int(os.getenv('DEVICE_ID'))
        context.set_context(device_id=device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    ckpt = load_checkpoint(args_opt.checkpoint)
    load_param_into_net(net, ckpt)
    net.set_train(False)

    val_data_url = os.path.join(args_opt.data_path, 'val')
    dataset = create_dataset_val(cfg.batch_size, val_data_url, workers=cfg.workers, distributed=False)

    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    eval_metrics = {'Loss': nn.Loss(),
                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
    metrics = model.eval(dataset)
    print("metric: ", metrics)

@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
RANK_SIZE=8
rm ${current_exec_path}/device_parallel/ -rf
mkdir ${current_exec_path}/device_parallel
echo ${curtime} > ${current_exec_path}/device_parallel/starttime
mpirun --allow-run-as-root -n $RANK_SIZE python ${current_exec_path}/train.py \
--GPU \
--distributed \
--data_path ${DATA_DIR} \
--cur_time ${curtime} > ${current_exec_path}/device_parallel/efficientnet_b0.log 2>&1 &

@@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2
PATH_CHECKPOINT=$3
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
echo ${curtime} > ${current_exec_path}/eval_starttime
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ./eval.py --platform 'GPU' --data_path ${DATA_DIR} --checkpoint ${PATH_CHECKPOINT} > ${current_exec_path}/eval.log 2>&1 &

@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
DEVICE_ID=$2
current_exec_path=$(pwd)
echo ${current_exec_path}
curtime=`date '+%Y%m%d-%H%M%S'`
rm ${current_exec_path}/device_${DEVICE_ID}/ -rf
mkdir ${current_exec_path}/device_${DEVICE_ID}
echo ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/starttime
CUDA_VISIBLE_DEVICES=${DEVICE_ID} python ${current_exec_path}/train.py \
--GPU \
--data_path ${DATA_DIR} \
--cur_time ${curtime} > ${current_exec_path}/device_${DEVICE_ID}/efficientnet_b0.log 2>&1 &

@@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting
"""
from easydict import EasyDict as edict
efficientnet_b0_config_gpu = edict({
    'random_seed': 1,
    'model': 'efficientnet_b0',
    'drop': 0.2,
    'drop_connect': 0.2,
    'opt_eps': 0.001,
    'lr': 0.064,
    'batch_size': 128,
    'decay_epochs': 2.4,
    'warmup_epochs': 5,
    'decay_rate': 0.97,
    'weight_decay': 1e-5,
    'epochs': 600,
    'workers': 8,
    'amp_level': 'O0',
    'opt': 'rmsprop',
    'num_classes': 1000,
    # type of global pool: "avg", "max", "avgmax", "avgmaxc"
    'gp': 'avg',
    'momentum': 0.9,
    'warmup_lr_init': 0.0001,
    'smoothing': 0.1,
    # use Tensorflow BatchNorm defaults for models that support it
    'bn_tf': False,
    'keep_checkpoint_max': 10,
    'loss_scale': 1024,
    'resume_start_epoch': 0,
})

@@ -0,0 +1,125 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import math
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
from mindspore.communication.management import get_group_size, get_rank
from mindspore.dataset.vision import Inter
from src.config import efficientnet_b0_config_gpu as cfg
from src.transform import RandAugment
ds.config.set_seed(cfg.random_seed)
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
img_size = (224, 224)
crop_pct = 0.875
rescale = 1.0 / 255.0
shift = 0.0
inter_method = 'bilinear'
resize_value = 224 # img_size
scale = (0.08, 1.0)
ratio = (3./4., 4./3.)
inter_str = 'bicubic'
def str2MsInter(method):
    if method == 'bicubic':
        return Inter.BICUBIC
    if method == 'nearest':
        return Inter.NEAREST
    return Inter.BILINEAR


def create_dataset(batch_size, train_data_url='', workers=8, distributed=False):
    if not os.path.exists(train_data_url):
        raise ValueError('Path not exists')
    interpolation = str2MsInter(inter_str)

    c_decode_op = C.Decode()
    type_cast_op = C2.TypeCast(mstype.int32)
    random_resize_crop_op = C.RandomResizedCrop(size=(resize_value, resize_value), scale=scale, ratio=ratio,
                                                interpolation=interpolation)
    random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)
    efficient_rand_augment = RandAugment()

    image_ops = [c_decode_op, random_resize_crop_op, random_horizontal_flip_op]

    rank_id = get_rank() if distributed else 0
    rank_size = get_group_size() if distributed else 1

    dataset_train = ds.ImageFolderDataset(train_data_url,
                                          num_parallel_workers=workers,
                                          shuffle=True,
                                          num_shards=rank_size,
                                          shard_id=rank_id)
    dataset_train = dataset_train.map(input_columns=["image"],
                                      operations=image_ops,
                                      num_parallel_workers=workers)
    dataset_train = dataset_train.map(input_columns=["label"],
                                      operations=type_cast_op,
                                      num_parallel_workers=workers)
    ds_train = dataset_train.batch(batch_size,
                                   per_batch_map=efficient_rand_augment,
                                   input_columns=["image", "label"],
                                   num_parallel_workers=2,
                                   drop_remainder=True)
    ds_train = ds_train.repeat(1)
    return ds_train


def create_dataset_val(batch_size=128, val_data_url='', workers=8, distributed=False):
    if not os.path.exists(val_data_url):
        raise ValueError('Path not exists')
    rank_id = get_rank() if distributed else 0
    rank_size = get_group_size() if distributed else 1
    dataset = ds.ImageFolderDataset(val_data_url, num_parallel_workers=workers,
                                    num_shards=rank_size, shard_id=rank_id, shuffle=False)
    scale_size = None
    interpolation = str2MsInter(inter_method)

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    type_cast_op = C2.TypeCast(mstype.int32)
    decode_op = C.Decode()
    resize_op = C.Resize(size=scale_size, interpolation=interpolation)
    center_crop = C.CenterCrop(size=224)
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
    changeswap_op = C.HWC2CHW()
    ctrans = [decode_op, resize_op, center_crop, rescale_op, normalize_op, changeswap_op]

    dataset = dataset.map(input_columns=["label"], operations=type_cast_op, num_parallel_workers=workers)
    dataset = dataset.map(input_columns=["image"], operations=ctrans, num_parallel_workers=workers)
    dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=workers)
    dataset = dataset.repeat(1)
    return dataset

File diff suppressed because it is too large.

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore import Tensor
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
    def __init__(self, smooth_factor=0.1, num_classes=1000):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss_logit = self.ce(logits, one_hot_label)
        loss_logit = self.mean(loss_logit, 0)
        return loss_logit

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
random augment class
"""
import numpy as np
import mindspore.dataset.vision.py_transforms as P
from src import transform_utils
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
class RandAugment:
    # config_str: the RandAugment policy string
    # hparams: optional dict of extra hyper-parameters
    def __init__(self, config_str="rand-m9-mstd0.5", hparams=None):
        hparams = hparams if hparams is not None else {}
        self.config_str = config_str
        self.hparams = hparams

    def __call__(self, imgs, labels, batchInfo):
        # the imgs objects are expected to be PIL-convertible images
        ret_imgs = []
        ret_labels = []
        py_to_pil_op = P.ToPIL()
        to_tensor = P.ToTensor()
        normalize_op = P.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        rand_augment_ops = transform_utils.rand_augment_transform(self.config_str, self.hparams)
        for i, image in enumerate(imgs):
            img_pil = py_to_pil_op(image)
            img_pil = rand_augment_ops(img_pil)
            img_array = to_tensor(img_pil)
            img_array = normalize_op(img_array)
            ret_imgs.append(img_array)
            ret_labels.append(labels[i])
        return np.array(ret_imgs), np.array(ret_labels)

File diff suppressed because it is too large.

@@ -0,0 +1,191 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import argparse
import math
import os
import random
import time
import numpy as np
import mindspore
from mindspore import Tensor, context
from mindspore.communication.management import get_group_size, get_rank, init
from mindspore.nn import SGD, RMSProp
from mindspore.train.callback import (CheckpointConfig, LossMonitor,
ModelCheckpoint, TimeMonitor)
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import efficientnet_b0_config_gpu as cfg
from src.dataset import create_dataset
from src.efficientnet import efficientnet_b0
from src.loss import LabelSmoothingCrossEntropy
mindspore.common.set_seed(cfg.random_seed)
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
def get_lr(base_lr, total_epochs, steps_per_epoch, decay_steps=1,
           decay_rate=0.9, warmup_steps=0., warmup_lr_init=0., global_epoch=0):
    """Build a per-step LR schedule: linear warmup, then stepwise exponential decay;
    the first global_epoch epochs are skipped when resuming."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    global_steps = steps_per_epoch * global_epoch
    self_warmup_delta = ((base_lr - warmup_lr_init) /
                         warmup_steps) if warmup_steps > 0 else 0
    self_decay_rate = decay_rate if decay_rate < 1 else 1 / decay_rate
    for i in range(total_steps):
        steps = math.floor(i / steps_per_epoch)
        cond = 1 if (steps < warmup_steps) else 0
        warmup_lr = warmup_lr_init + steps * self_warmup_delta
        decay_nums = math.floor(steps / decay_steps)
        decay_rate = math.pow(self_decay_rate, decay_nums)
        decay_lr = base_lr * decay_rate
        lr = cond * warmup_lr + (1 - cond) * decay_lr
        lr_each_step.append(lr)
    lr_each_step = lr_each_step[global_steps:]
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step


def get_outdir(path, *paths, inc=False):
    """Create the output directory; with inc=True, append an incrementing suffix if it exists."""
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        count = 1
        outdir_inc = outdir + '-' + str(count)
        while os.path.exists(outdir_inc):
            count = count + 1
            outdir_inc = outdir + '-' + str(count)
            assert count < 100
        outdir = outdir_inc
        os.makedirs(outdir)
    return outdir
parser = argparse.ArgumentParser(
description='Training configuration', add_help=False)
parser.add_argument('--data_path', type=str, default='/home/dataset/imagenet_jpeg/', metavar='DIR',
help='path to dataset')
parser.add_argument('--distributed', action='store_true', default=False)
parser.add_argument('--GPU', action='store_true', default=False,
help='Use GPU for training (default: False)')
parser.add_argument('--cur_time', type=str,
default='19701010-000000', help='current time')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
def main():
    args, _ = parser.parse_known_args()
    devid, rank_id, rank_size = 0, 0, 1

    context.set_context(mode=context.GRAPH_MODE)

    if args.distributed:
        if args.GPU:
            init("nccl")
            context.set_context(device_target='GPU')
        else:
            init()
            devid = int(os.getenv('DEVICE_ID'))
            context.set_context(
                device_target='Ascend', device_id=devid, reserve_class_name_in_scope=False)
        context.reset_auto_parallel_context()
        rank_id = get_rank()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, device_num=rank_size)
    else:
        if args.GPU:
            context.set_context(device_target='GPU')

    is_master = not args.distributed or (rank_id == 0)

    net = efficientnet_b0(num_classes=cfg.num_classes,
                          drop_rate=cfg.drop,
                          drop_connect_rate=cfg.drop_connect,
                          global_pool=cfg.gp,
                          bn_tf=cfg.bn_tf,
                          )

    cur_time = args.cur_time
    output_base = './output'
    exp_name = '-'.join([
        cur_time,
        cfg.model,
        str(224)
    ])
    time.sleep(rank_id)
    output_dir = get_outdir(output_base, exp_name)

    train_data_url = os.path.join(args.data_path, 'train')
    train_dataset = create_dataset(
        cfg.batch_size, train_data_url, workers=cfg.workers, distributed=args.distributed)
    batches_per_epoch = train_dataset.get_dataset_size()

    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    loss = LabelSmoothingCrossEntropy(smooth_factor=cfg.smoothing)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    loss_scale_manager = FixedLossScaleManager(
        cfg.loss_scale, drop_overflow_update=False)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(
        prefix=cfg.model, directory=output_dir, config=config_ck)

    lr = Tensor(get_lr(base_lr=cfg.lr, total_epochs=cfg.epochs, steps_per_epoch=batches_per_epoch,
                       decay_steps=cfg.decay_epochs, decay_rate=cfg.decay_rate,
                       warmup_steps=cfg.warmup_epochs, warmup_lr_init=cfg.warmup_lr_init,
                       global_epoch=cfg.resume_start_epoch))

    if cfg.opt == 'sgd':
        optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                        weight_decay=cfg.weight_decay,
                        loss_scale=cfg.loss_scale
                        )
    elif cfg.opt == 'rmsprop':
        optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=cfg.weight_decay,
                            momentum=cfg.momentum, epsilon=cfg.opt_eps, loss_scale=cfg.loss_scale
                            )

    loss.add_flags_recursive(fp32=True, fp16=False)

    if args.resume:
        ckpt = load_checkpoint(args.resume)
        load_param_into_net(net, ckpt)

    model = Model(net, loss, optimizer,
                  loss_scale_manager=loss_scale_manager,
                  amp_level=cfg.amp_level
                  )

    callbacks = [loss_cb, ckpoint_cb, time_cb] if is_master else []

    if args.resume:
        real_epoch = cfg.epochs - cfg.resume_start_epoch
        model.train(real_epoch, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)
    else:
        model.train(cfg.epochs, train_dataset,
                    callbacks=callbacks, dataset_sink_mode=True)


if __name__ == '__main__':
    main()