You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/research/cv/FaceQualityAssessment/train.py

188 lines
7.5 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment train."""
import os
import time
import datetime
import argparse
import warnings
import numpy as np
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from src.loss import CriterionsFaceQA
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
from src.face_qa import FaceQABackbone, BuildTrainNetwork
from src.lr_generator import warmup_step
from src.dataset import faceqa_dataset
from src.log import get_logger, AverageMeter
# Module-level runtime setup: silence Python warnings, bind this process to an
# Ascend device in graph mode, and fix the global random seed.
warnings.filterwarnings('ignore')
# DEVICE_ID is expected in the environment (presumably exported by the launch
# script). Fall back to device 0 so a hand-started single-device run does not
# crash with a TypeError from int(None).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
mindspore.common.seed.set_seed(1)
def _select_config(args):
    """Return the 1p or 8p training config with command-line overrides applied.

    The selected module-level config object is mutated in place.
    """
    cfg = faceqa_1p_cfg if args.is_distributed == 0 else faceqa_8p_cfg
    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained
    return cfg


def _load_pretrained(network, cfg):
    """Load the checkpoint at ``cfg.pretrained`` into ``network`` if it is a file.

    Entries whose key starts with ``moments.`` are skipped, and a leading
    ``network.`` prefix is stripped so the remaining keys can be matched
    against the bare backbone (the prefix is presumably added when the
    checkpoint was saved from a wrapping training cell).
    """
    if not os.path.isfile(cfg.pretrained):
        # An empty path means "no pretrained model"; anything else is likely a typo.
        if cfg.pretrained:
            cfg.logger.info('pretrained model {} not found, skip loading'.format(cfg.pretrained))
        return
    prefix = 'network.'
    param_dict = load_checkpoint(cfg.pretrained)
    param_dict_new = {}
    for key, value in param_dict.items():
        if key.startswith('moments.'):
            continue
        if key.startswith(prefix):
            key = key[len(prefix):]
        param_dict_new[key] = value
    load_param_into_net(network, param_dict_new)
    cfg.logger.info('load model {} success'.format(cfg.pretrained))


def main(args):
    """Train the face quality assessment backbone.

    Args:
        args: parsed command-line namespace providing ``is_distributed``,
            ``train_label_file`` and ``pretrained``.
    """
    cfg = _select_config(args)

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    # STAND_ALONE mode does not support parameter_broadcast or mirror_mean.
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)
    mindspore.common.set_seed(1)

    # logger: one timestamped output directory per run
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader: steps_per_epoch is read before repeat() so it counts one epoch.
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst, local_rank=cfg.local_rank, world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()
    network = FaceQABackbone()
    criterion = CriterionsFaceQA()
    _load_pretrained(network, cfg)

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process: forward + loss + backward + optimizer step
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # Only rank 0 writes checkpoints and per-step/epoch logs.
    is_master = cfg.local_rank == 0

    # checkpoint save
    if is_master:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix=str(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1
    i = -1  # stays -1 if the dataloader yields nothing
    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # put data into device + forward + backward + optimizer, return loss
        data = Tensor(data.astype(np.float32))
        gt = Tensor(gt.astype(np.float32))
        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt: the callback is driven manually because no Model.train loop is used.
        if is_master:
            cb_params.cur_step_num = i + 1  # current step number
            # Kept larger than cur_step_num, presumably so the checkpoint file
            # name carries the global step -- NOTE(review): confirm intent.
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and is_master:
            time_used = time.time() - t_end
            epoch = i // cfg.steps_per_epoch
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        # Summarize on the LAST step of each epoch so the timed window covers
        # exactly steps_per_epoch steps (firing at i == 0 would divide a full
        # epoch's images by a single step's time). The first window also
        # includes graph compile time, since t_epoch is set before the loop.
        if (i + 1) % cfg.steps_per_epoch == 0 and is_master:
            epoch_time_used = time.time() - t_epoch
            epoch = i // cfg.steps_per_epoch
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    # Let the checkpoint callback force-save the final step (otherwise lost when
    # the total step count is not a multiple of ckpt_interval) and finish any
    # pending saves. Skipped when no step ran, since cur_step_num is unset then.
    if is_master and i >= 0:
        ckpt_cb.end(run_context)

    cfg.logger.important_info('====train end====')
if __name__ == "__main__":
    # Command-line entry point. Every option has a default, so the script can be
    # started without arguments (non-distributed, no pretrained checkpoint).
    cli = argparse.ArgumentParser(description='Face Quality Assessment')
    cli_options = (
        # (flag, type, default, help)
        ('--is_distributed', int, 0, 'if multi device'),
        ('--train_label_file', str, '', 'image label list file, e.g. /home/label.txt'),
        ('--pretrained', str, '', 'pretrained model to load'),
    )
    for flag, flag_type, flag_default, flag_help in cli_options:
        cli.add_argument(flag, type=flag_type, default=flag_default, help=flag_help)
    main(cli.parse_args())