mindspore/model_zoo/research/cv/FaceRecognitionForTracking/train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import time
import argparse
import datetime
import warnings
import random
import numpy as np
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn.optim import SGD
from mindspore.nn import TrainOneStepCell
from mindspore.communication.management import get_group_size, init, get_rank
from src.dataset import get_de_dataset
from src.config import reid_1p_cfg, reid_8p_cfg
from src.lr_generator import step_lr
from src.log import get_logger, AverageMeter
from src.reid import SphereNet, CombineMarginFCFp16, BuildTrainNetworkWithHead
from src.loss import CrossEntropy
warnings.filterwarnings('ignore')
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid)
random.seed(1)
np.random.seed(1)


def main():
    parser = argparse.ArgumentParser(description='Face Recognition For Tracking')
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--data_dir', type=str, default='', help='image label list file, e.g. /home/label.txt')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    args = parser.parse_args()

    if args.is_distributed == 0:
        cfg = reid_1p_cfg
    else:
        cfg = reid_8p_cfg

    cfg.pretrained = args.pretrained
    cfg.data_dir = args.data_dir
    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)
    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Show cfg
    cfg.logger.save_args(cfg)

    # dataloader
    cfg.logger.info('start create dataloader')
    de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
    cfg.steps_per_epoch = steps_per_epoch
    cfg.logger.info('step per epoch: ' + str(cfg.steps_per_epoch))
    de_dataloader = de_dataset.create_tuple_iterator()
    cfg.logger.info('class num original: ' + str(class_num))
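    # Round class_num up to the next multiple of 16 below; presumably this keeps the FC head
    # shape aligned for fp16 computation on Ascend (an assumption, not stated in the script).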
    if class_num % 16 != 0:
        class_num = (class_num // 16 + 1) * 16
    cfg.class_num = class_num
    cfg.logger.info('change the class num to :' + str(cfg.class_num))
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()
    network = SphereNet(num_layers=cfg.net_depth, feature_dim=cfg.embedding_size, shape=cfg.input_size)
    head = CombineMarginFCFp16(embbeding_size=cfg.embedding_size, classnum=cfg.class_num)
    criterion = CrossEntropy()

    # load the pretrained model
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
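        # Keep only backbone weights: skip optimizer momentum entries ('moments.*') and
        # strip the 'network.' prefix (presumably left by the training wrapper used for pretraining).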
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model {} success'.format(cfg.pretrained))

    # mixed precision training
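    # (the backbone and margin head run in fp16, while the loss is kept in fp32 for numerical stability)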
    network.add_flags_recursive(fp16=True)
    head.add_flags_recursive(fp16=True)
    criterion.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetworkWithHead(network, head, criterion)

    # optimizer and lr scheduler
    lr = step_lr(lr=cfg.lr, epoch_size=cfg.epoch_size, steps_per_epoch=cfg.steps_per_epoch, max_epoch=cfg.max_epoch,
                 gamma=cfg.lr_gamma)
    opt = SGD(params=train_net.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
              weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save
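    # The training loop below is hand-written rather than Model.train(), so on rank 0 the
    # ModelCheckpoint callback is driven manually via RunContext and _InternalCallbackParam.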
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, total_data in enumerate(de_dataloader):
        data, gt = total_data
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
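            # throughput = images per device per step * steps since the last log * device count,
            # divided by the wall-clock time spent over that window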
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec, lr={}'.format(epoch, i, loss_meter, fps, lr[i]))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


if __name__ == "__main__":
    main()
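# A minimal launch sketch (the paths and device id below are placeholders, not part of this repository):
#   DEVICE_ID=0 python train.py --is_distributed 0 --data_dir /path/to/label.txt --pretrained /path/to/pretrained.ckpt
# For the 8-device configuration, each process is expected to export its own DEVICE_ID and pass
# --is_distributed 1 so that reid_8p_cfg and data-parallel training are selected.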