You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/official/cv/resnext101/train.py

260 lines
9.7 KiB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ImageNet."""
import os
import time
import argparse
import datetime
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.common import set_seed
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.lr_generator import get_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.config import config
set_seed(1)
class BuildTrainNetwork(nn.Cell):
"""build training network"""
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class ProgressMonitor(Callback):
"""monitor loss and time"""
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.me_epoch_start_time = 0
self.me_epoch_start_step_num = 0
self.args = args
self.ckpt_history = []
def begin(self, run_context):
self.args.logger.info('start network train...')
def epoch_begin(self, run_context):
pass
def epoch_end(self, run_context, *me_args):
cb_params = run_context.original_args()
me_step = cb_params.cur_step_num - 1
real_epoch = me_step // self.args.steps_per_epoch
time_used = time.time() - self.me_epoch_start_time
fps_mean = self.args.per_batch_size * (me_step-self.me_epoch_start_step_num) * self.args.group_size / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}'
'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
if self.args.rank_save_ckpt_flag:
import glob
ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
for ckpt in ckpts:
ckpt_fn = os.path.basename(ckpt)
if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
continue
if ckpt in self.ckpt_history:
continue
self.ckpt_history.append(ckpt)
self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))
self.me_epoch_start_step_num = me_step
self.me_epoch_start_time = time.time()
def step_begin(self, run_context):
pass
def step_end(self, run_context, *me_args):
pass
def end(self, run_context):
self.args.logger.info('end network train...')
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained', default='', type=str, help='model_path, local pretrained model to load')
# distributed related
parser.add_argument('--is_distributed', action="store_true", default=False, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config.image_size
args.num_classes = config.num_classes
args.lr = config.lr
args.lr_scheduler = config.lr_scheduler
args.lr_epochs = config.lr_epochs
args.lr_gamma = config.lr_gamma
args.eta_min = config.eta_min
args.T_max = config.T_max
args.max_epoch = config.max_epoch
args.warmup_epochs = config.warmup_epochs
args.weight_decay = config.weight_decay
args.momentum = config.momentum
args.is_dynamic_loss_scale = config.is_dynamic_loss_scale
args.loss_scale = config.loss_scale
args.label_smooth = config.label_smooth
args.label_smooth_factor = config.label_smooth_factor
args.ckpt_interval = config.ckpt_interval
args.ckpt_save_max = config.ckpt_save_max
args.ckpt_path = config.ckpt_path
args.is_save_on_master = config.is_save_on_master
args.rank = config.rank
args.group_size = config.group_size
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.image_size = list(map(int, args.image_size.split(',')))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.platform, save_graphs=False)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
else:
args.rank = 0
args.group_size = 1
if args.is_dynamic_loss_scale == 1:
args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
return args
def merge_args(args, cloud_args):
"""dictionary"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def train(cloud_args=None):
"""training process"""
args = parse_args(cloud_args)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
# init distributed
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
gradients_mean=True)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, 1,
args.rank, args.group_size, num_parallel_workers=8)
de_dataset.map_model = 4 # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# get network and init
network = get_network(num_classes=args.num_classes, platform=args.platform)
load_pretrain_model(args.pretrained, network, args)
# lr scheduler
lr = get_lr(args)
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
# loss
if not args.label_smooth:
args.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
if args.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
# checkpoint save
progress_cb = ProgressMonitor(args)
callbacks = [progress_cb,]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
if __name__ == "__main__":
train()