# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import argparse

import mindspore
from mindspore.nn import Cell
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config_base, config_beta
from src.my_logging import get_logger
from src.init_network import init_net
from src.dataset_factory import get_de_dataset
from src.backbone.resnet import get_backbone
from src.metric_factory import get_metric_fc
from src.loss_factory import get_loss
from src.lrsche_factory import warmup_step_list, list_to_gen
from src.callback_factory import ProgressMonitor

mindspore.common.seed.set_seed(1)
# DEVICE_ID is exported per process by the launch script; default to '0' so a
# bare single-device run does not crash on a missing variable.
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=devid, reserve_class_name_in_scope=False,
                    enable_auto_mixed_precision=False)

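# DistributedHelper lets the data-parallel backbone and the (possibly
# model-parallel) margin FC head be wrapped and executed as a single cell.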
class DistributedHelper(Cell):
    '''Wrap the backbone and the optional margin FC head into one cell.'''
    def __init__(self, backbone, margin_fc):
        super(DistributedHelper, self).__init__()
        self.backbone = backbone
        self.margin_fc = margin_fc
        if margin_fc is not None:
            self.has_margin_fc = 1
        else:
            self.has_margin_fc = 0

    def construct(self, x, label):
        embeddings = self.backbone(x)
        if self.has_margin_fc == 1:
            return embeddings, self.margin_fc(embeddings, label)
        return embeddings

class BuildTrainNetwork(Cell):
    '''Combine the network with its criterion to output the training loss.'''
    def __init__(self, network, criterion, args_1):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.args = args_1

        if int(args_1.model_parallel) == 0:
            self.is_model_parallel = 0
        else:
            self.is_model_parallel = 1

    def construct(self, input_data, label):
        if self.is_model_parallel == 0:
            _, output = self.network(input_data, label)
            loss = self.criterion(output, label)
        else:
            # model-parallel path: no gathered logits are returned, so the
            # criterion is called with None and only consumes the label
            _ = self.network(input_data, label)
            loss = self.criterion(None, label)

        return loss

def parse_args():
    parser = argparse.ArgumentParser('MindSpore Face Recognition')
    parser.add_argument('--train_stage', type=str, default='base', help='training stage: base or beta')
    parser.add_argument('--is_distributed', type=int, default=1, help='whether to run on multiple devices')

    args_opt_1, _ = parser.parse_known_args()
    return args_opt_1

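# Example launch for a single device (assuming DEVICE_ID is exported as read
# above; the flag names match parse_args):
#   DEVICE_ID=0 python train.py --train_stage base --is_distributed 0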
if __name__ == "__main__":
    args_opt = parse_args()

    support_train_stage = ['base', 'beta']
    if args_opt.train_stage.lower() not in support_train_stage:
        # args (and args.logger) do not exist yet at this point, so print directly
        print('supported train stages are: {}, while yours is: {}'.format(
            support_train_stage, args_opt.train_stage))
        raise ValueError('unsupported train stage.')

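    # Each stage ('base' / 'beta') carries its own config. Distributed runs use
    # HYBRID_PARALLEL so the backbone trains data-parallel while the margin FC
    # can be sharded across devices; standalone runs use STAND_ALONE.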
    args = config_base if args_opt.train_stage.lower() == 'base' else config_beta
    args.is_distributed = args_opt.is_distributed
    if args_opt.is_distributed:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
        parallel_mode = ParallelMode.HYBRID_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE

    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=args.world_size, gradients_mean=True)

    if not os.path.exists(args.data_dir):
        # args.logger is only created below, so raise directly here
        raise ValueError('ERROR, data_dir does not exist, please set data_dir in config.py')

    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))

    log_path = os.path.join(args.ckpt_path, 'logs')
    args.logger = get_logger(log_path, args.local_rank)

    if args.local_rank % 8 == 0:
        # only the first rank of each 8-device node creates the checkpoint directory
        if not os.path.exists(args.ckpt_path):
            os.makedirs(args.ckpt_path)

    args.logger.info('args.world_size:{}'.format(args.world_size))
    args.logger.info('args.local_rank:{}'.format(args.local_rank))
    args.logger.info('args.lr:{}'.format(args.lr))

    momentum = args.momentum
    weight_decay = args.weight_decay

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(args)
    args.logger.info('de_dataset:{}'.format(de_dataset.get_dataset_size()))
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes

    args.logger.info('loaded, nums: {}'.format(args.num_classes))
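    # Optionally pad num_classes so the margin FC weight has 16-aligned shapes
    # (and splits evenly across devices under model parallel), which Ascend's
    # matmul tiling favors.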
    if args.nc_16 == 1:
        if args.model_parallel == 0:
            if args.num_classes % 16 == 0:
                args.logger.info('data parallel, num_classes already a multiple of 16: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // 16 + 1) * 16
        else:
            if args.num_classes % (args.world_size * 16) == 0:
                args.logger.info('model parallel, num_classes already a multiple of {}: {}'.format(
                    args.world_size * 16, args.num_classes))
            else:
                args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16

    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))

    args.logger.info('get_backbone----in----')
    _backbone = get_backbone(args)
    args.logger.info('get_backbone----out----')

    args.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(args)
    args.logger.info('get_metric_fc----out----')

    args.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    args.logger.info('DistributedHelper----out----')

    args.logger.info('network fp16----in----')
    if args.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    args.logger.info('network fp16----out----')

    criterion_1 = get_loss(args)
    if args.fp16 == 1 and args.model_parallel == 0:
        # compute the loss in fp32 even when the network runs in fp16
        criterion_1.add_flags_recursive(fp32=True)

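    # Warm-start from a pretrained checkpoint when one is given; otherwise
    # initialize the network from scratch via init_net.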
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        if args_opt.train_stage.lower() == 'base':
            # drop optimizer moments and strip the 'network.' prefix
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                if key.startswith('network.'):
                    param_dict_new[key[len('network.'):]] = value
        else:
            # beta stage: additionally skip bn1/se/head/margin_fc weights,
            # which keep their fresh initialization instead of being restored
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                if key.startswith('network.'):
                    if 'layers.' in key and 'bn1' in key:
                        continue
                    if 'se' in key:
                        continue
                    if 'head' in key:
                        continue
                    if 'margin_fc.weight' in key:
                        continue
                    param_dict_new[key[len('network.'):]] = value
        load_param_into_net(network_1, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
    else:
        init_net(args, network_1)

    train_net = BuildTrainNetwork(network_1, criterion_1, args)

    args.logger.info('args:{}'.format(args))
    # warmup_step_list must be called after args.steps_per_epoch has been set
    args.lrs = warmup_step_list(args, gamma=0.1)
    lrs_gen = list_to_gen(args.lrs)
    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen, momentum=momentum,
                   weight_decay=weight_decay)
    scale_manager = DynamicLossScaleManager(init_loss_scale=args.dynamic_init_loss_scale, scale_factor=2,
                                            scale_window=2000)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)
    save_checkpoint_steps = args.ckpt_steps
    args.logger.info('save_checkpoint_steps:{}'.format(save_checkpoint_steps))
    if args.max_ckpts == -1:
        # keep every checkpoint the run can produce, plus a margin of 5
        keep_checkpoint_max = int(args.steps_per_epoch * args.max_epoch / save_checkpoint_steps) + 5
    else:
        keep_checkpoint_max = args.max_ckpts
    args.logger.info('keep_checkpoint_max:{}'.format(keep_checkpoint_max))

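    # Each rank saves its checkpoints under a prefix of its own local rank id,
    # so files written by different devices do not collide.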
    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps, keep_checkpoint_max=keep_checkpoint_max)
    max_epoch_train = args.max_epoch
    args.logger.info('max_epoch_train:{}'.format(max_epoch_train))
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.ckpt_path, prefix='{}'.format(args.local_rank))
    args.epoch_cnt = 0
    progress_cb = ProgressMonitor(args)
    # with data sinking, model.train counts sink iterations of sink_size steps
    # each, so convert the total step count into log_interval-sized chunks
    new_epoch_train = max_epoch_train * steps_per_epoch // args.log_interval
    model.train(new_epoch_train, de_dataset, callbacks=[progress_cb, ckpt_cb], sink_size=args.log_interval)