# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Quality Assessment train."""
import os
import time
import datetime
import argparse
import warnings
import numpy as np

import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank

from src.loss import CriterionsFaceQA
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
from src.face_qa import FaceQABackbone, BuildTrainNetwork
from src.lr_generator import warmup_step
from src.dataset import faceqa_dataset
from src.log import get_logger, AverageMeter

warnings.filterwarnings('ignore')

# Device to run on. Defaults to 0 when DEVICE_ID is not exported; previously
# int(os.getenv('DEVICE_ID')) raised TypeError in that case.
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
mindspore.common.seed.set_seed(1)

# Checkpoint entries with this prefix are skipped when warm-starting
# (presumably optimizer momentum state; confirm against the checkpoints you load).
_MOMENTS_PREFIX = 'moments.'
# Leading prefix stripped from checkpoint entry names before loading into the
# backbone (presumably added by a wrapping training cell; confirm).
_NETWORK_PREFIX = 'network.'


def _load_pretrained(network, ckpt_file, logger):
    """Warm-start ``network`` from the checkpoint file ``ckpt_file``.

    Entries whose names start with 'moments.' are dropped, and a leading
    'network.' is stripped from the remaining names before
    ``load_param_into_net`` is called.

    Args:
        network: the cell to load parameters into.
        ckpt_file: checkpoint path. An empty string means "no pretrained model".
        logger: project logger used to report the outcome.
    """
    if not os.path.isfile(ckpt_file):
        # A non-empty path that is not a file is most likely a typo; report it
        # instead of silently training without the requested weights.
        if ckpt_file:
            logger.info('pretrained model {} is not a file, skip loading'.format(ckpt_file))
        return

    param_dict = load_checkpoint(ckpt_file)
    param_dict_new = {}
    for key, values in param_dict.items():
        if key.startswith(_MOMENTS_PREFIX):
            continue
        if key.startswith(_NETWORK_PREFIX):
            key = key[len(_NETWORK_PREFIX):]
        param_dict_new[key] = values
    load_param_into_net(network, param_dict_new)
    logger.info('load model {} success'.format(ckpt_file))


def _create_checkpoint_callback(cfg, train_net):
    """Build a ModelCheckpoint callback that the training loop drives by hand.

    ``keep_checkpoint_max`` (and ``cb_params.epoch_num``) is set to the number
    of checkpoints the whole run can produce:
    ``max_epoch * steps_per_epoch // ckpt_interval``.

    Args:
        cfg: training config; reads max_epoch, steps_per_epoch, ckpt_interval,
            outputs_dir and local_rank.
        train_net: the training cell whose parameters are checkpointed.

    Returns:
        tuple: ``(ckpt_cb, cb_params, run_context)``. ``ckpt_cb.begin`` has
        already been called.
    """
    ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
    train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
    ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
    cb_params = _InternalCallbackParam()
    cb_params.train_network = train_net
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    return ckpt_cb, cb_params, run_context


def main(args):
    """Train the face quality assessment network.

    Args:
        args: parsed command-line namespace with
            ``is_distributed`` (0 selects the single-device config),
            ``train_label_file`` (image label list file) and
            ``pretrained`` (checkpoint to warm-start from, may be empty).
    """
    cfg = faceqa_1p_cfg if args.is_distributed == 0 else faceqa_8p_cfg

    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        # local_rank / world_size keep whatever values faceqa_1p_cfg provides.
        parallel_mode = ParallelMode.STAND_ALONE

    # The STAND_ALONE parallel mode does not support parameter_broadcast or mirror_mean.
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)
    mindspore.common.set_seed(1)

    # logger
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst, local_rank=cfg.local_rank, world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    # A single flat iterator over all epochs; epoch boundaries are recovered
    # from the step index i in the loop below.
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrain model
    _load_pretrained(network, cfg.pretrained, cfg.logger)

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(), learning_rate=lr, momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay, loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    # sens uses the same cfg.loss_scale that was handed to the optimizer above.
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale)

    # checkpoint save (rank 0 only)
    if cfg.local_rank == 0:
        ckpt_cb, cb_params, run_context = _create_checkpoint_callback(cfg, train_net)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = Tensor(data.astype(np.float32))
        gt = Tensor(gt.astype(np.float32))
        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            # NOTE(review): batch_num is always one ahead of cur_step_num,
            # presumably so ModelCheckpoint's per-epoch step counter never wraps;
            # confirm before changing.
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            time_used = time.time() - t_end
            epoch = i // cfg.steps_per_epoch
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            # At i == 0 only one step (including graph compilation) has run, so a
            # whole-epoch throughput would be meaningless: just start the timer.
            if i > 0:
                epoch_time_used = time.time() - t_epoch
                epoch = i // cfg.steps_per_epoch
                fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
                cfg.logger.info('=================================================')
                cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
                cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Face Quality Assessment')
    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
    parser.add_argument('--train_label_file', type=str, default='', help='image label list file, e.g. /home/label.txt')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')

    arg = parser.parse_args()

    main(arg)