You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
7.5 KiB
188 lines
7.5 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Face Quality Assessment train."""
|
|
import os
|
|
import time
|
|
import datetime
|
|
import argparse
|
|
import warnings
|
|
import numpy as np
|
|
|
|
import mindspore
|
|
from mindspore import context
|
|
from mindspore import Tensor
|
|
from mindspore.context import ParallelMode
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
|
|
from mindspore.nn import TrainOneStepCell
|
|
from mindspore.nn.optim import Momentum
|
|
from mindspore.communication.management import get_group_size, init, get_rank
|
|
|
|
from src.loss import CriterionsFaceQA
|
|
from src.config import faceqa_1p_cfg, faceqa_8p_cfg
|
|
from src.face_qa import FaceQABackbone, BuildTrainNetwork
|
|
from src.lr_generator import warmup_step
|
|
from src.dataset import faceqa_dataset
|
|
from src.log import get_logger, AverageMeter
|
|
|
|
# Silence library warnings so the training log stays readable.
warnings.filterwarnings('ignore')

# Target Ascend device is chosen via the DEVICE_ID environment variable.
# Default to device 0 when it is unset instead of crashing on int(None).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
# Fix the global RNG seed once at import time for reproducible runs.
mindspore.common.seed.set_seed(1)
|
|
|
def main(args):
    """Train the Face Quality Assessment network.

    Selects the 1p or 8p config based on ``args.is_distributed``, builds the
    dataset/backbone/loss/optimizer, optionally loads a pretrained checkpoint,
    then runs the training loop with periodic checkpointing and fps logging.

    Args:
        args: parsed command-line namespace with ``is_distributed`` (int flag),
            ``train_label_file`` (path to the image/label list) and
            ``pretrained`` (checkpoint path, may be empty).
    """
    # Pick the single-device or 8-device preset config.
    if args.is_distributed == 0:
        cfg = faceqa_1p_cfg
    else:
        cfg = faceqa_8p_cfg

    # Command-line values override the preset config.
    cfg.data_lst = args.train_label_file
    cfg.pretrained = args.pretrained

    # Init distributed: rank/world_size come from the communication backend.
    if args.is_distributed:
        init()
        cfg.local_rank = get_rank()
        cfg.world_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    # parallel_mode 'STAND_ALONE' does not support parameter_broadcast and mirror_mean
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.world_size,
                                      gradients_mean=True)

    # NOTE(review): duplicates the module-level set_seed(1) call — harmless but redundant.
    mindspore.common.set_seed(1)

    # logger: each run gets a timestamped output directory under ckpt_path.
    cfg.outputs_dir = os.path.join(cfg.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    cfg.logger = get_logger(cfg.outputs_dir, cfg.local_rank)
    loss_meter = AverageMeter('loss')

    # Dataloader: sharded by rank; repeated so one iterator covers all epochs.
    cfg.logger.info('start create dataloader')
    de_dataset = faceqa_dataset(imlist=cfg.data_lst, local_rank=cfg.local_rank, world_size=cfg.world_size,
                                per_batch_size=cfg.per_batch_size)
    cfg.steps_per_epoch = de_dataset.get_dataset_size()
    de_dataset = de_dataset.repeat(cfg.max_epoch)
    de_dataloader = de_dataset.create_tuple_iterator(output_numpy=True)
    # Show cfg
    cfg.logger.save_args(cfg)
    cfg.logger.info('end create dataloader')

    # backbone and loss
    cfg.logger.important_info('start create network')
    create_network_start = time.time()  # used below to report graph-compile time

    network = FaceQABackbone()
    criterion = CriterionsFaceQA()

    # load pretrain model: strip optimizer state ('moments.') and the
    # 'network.' wrapper prefix so keys match the bare backbone.
    if os.path.isfile(cfg.pretrained):
        param_dict = load_checkpoint(cfg.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                continue
            elif key.startswith('network.'):
                param_dict_new[key[8:]] = values  # 8 == len('network.')
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        cfg.logger.info('load model {} success'.format(cfg.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(cfg, gamma=0.9)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=cfg.momentum,
                   weight_decay=cfg.weight_decay,
                   loss_scale=cfg.loss_scale)

    # package training process, adjust lr + forward + backward + optimizer
    train_net = BuildTrainNetwork(network, criterion)
    train_net = TrainOneStepCell(train_net, opt, sens=cfg.loss_scale,)

    # checkpoint save: rank 0 only, driven manually via the callback protocol
    # (_InternalCallbackParam is a private MindSpore API — may break on upgrade).
    if cfg.local_rank == 0:
        ckpt_max_num = cfg.max_epoch * cfg.steps_per_epoch // cfg.ckpt_interval
        train_config = CheckpointConfig(save_checkpoint_steps=cfg.ckpt_interval, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=cfg.outputs_dir, prefix='{}'.format(cfg.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()    # window start for interval fps
    t_epoch = time.time()  # window start for per-epoch fps
    old_progress = -1      # last logged step index

    cfg.logger.important_info('====start train====')
    for i, (data, gt) in enumerate(de_dataloader):
        # clean grad + adjust lr + put data into device + forward + backward + optimizer, return loss
        data = data.astype(np.float32)
        gt = gt.astype(np.float32)
        data = Tensor(data)
        gt = Tensor(gt)

        loss = train_net(data, gt)
        loss_meter.update(loss.asnumpy())

        # ckpt: advance the manual callback state so ModelCheckpoint saves
        # every ckpt_interval steps.
        if cfg.local_rank == 0:
            cb_params.cur_step_num = i + 1  # current step number
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)

        # logging loss, fps, ...
        if i == 0:
            # First step includes graph compilation; report it separately.
            time_for_graph_compile = time.time() - create_network_start
            cfg.logger.important_info('{}, graph compile time={:.2f}s'.format(cfg.task, time_for_graph_compile))

        if i % cfg.log_interval == 0 and cfg.local_rank == 0:
            # Throughput over the steps since the previous log line,
            # scaled by world_size to report cluster-wide imgs/sec.
            time_used = time.time() - t_end
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * (i - old_progress) * cfg.world_size / time_used
            cfg.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % cfg.steps_per_epoch == 0 and cfg.local_rank == 0:
            # Per-epoch throughput summary.
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / cfg.steps_per_epoch)
            fps = cfg.per_batch_size * cfg.world_size * cfg.steps_per_epoch / epoch_time_used
            cfg.logger.info('=================================================')
            cfg.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            cfg.logger.info('=================================================')
            t_epoch = time.time()

    cfg.logger.important_info('====train end====')
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: register the three training options and
    # hand the parsed namespace straight to main().
    cli = argparse.ArgumentParser(description='Face Quality Assessment')
    for flag, kind, fallback, note in (
            ('--is_distributed', int, 0, 'if multi device'),
            ('--train_label_file', str, '', 'image label list file, e.g. /home/label.txt'),
            ('--pretrained', str, '', 'pretrained model to load'),
    ):
        cli.add_argument(flag, type=kind, default=fallback, help=note)

    main(cli.parse_args())