# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face Recognition train."""
import os
import argparse

import mindspore
from mindspore.nn import Cell
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config_base, config_beta
from src.my_logging import get_logger
from src.init_network import init_net
from src.dataset_factory import get_de_dataset
from src.backbone.resnet import get_backbone
from src.metric_factory import get_metric_fc
from src.loss_factory import get_loss
from src.lrsche_factory import warmup_step_list, list_to_gen
from src.callback_factory import ProgressMonitor

mindspore.common.seed.set_seed(1)
# FIX: default to device 0 when DEVICE_ID is unset instead of crashing on int(None).
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    device_id=devid, reserve_class_name_in_scope=False,
                    enable_auto_mixed_precision=False)


class DistributedHelper(Cell):
    """Wrap a backbone and an optional margin-FC head into a single Cell.

    With a margin FC head, construct() returns both the embeddings and the
    head's output; without one, it returns embeddings only.
    """

    def __init__(self, backbone, margin_fc):
        super(DistributedHelper, self).__init__()
        self.backbone = backbone
        self.margin_fc = margin_fc
        # Flag kept as int (not bool) so it is graph-mode friendly.
        if margin_fc is not None:
            self.has_margin_fc = 1
        else:
            self.has_margin_fc = 0

    def construct(self, x, label):
        embeddings = self.backbone(x)
        if self.has_margin_fc == 1:
            return embeddings, self.margin_fc(embeddings, label)
        return embeddings


class BuildTrainNetwork(Cell):
    """Combine the forward network and the criterion into one loss Cell.

    In model-parallel mode the network output is ignored here and the
    criterion is called with None (the loss is computed inside the
    parallel head); otherwise the head output feeds the criterion directly.
    """

    def __init__(self, network, criterion, args_1):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.args = args_1
        if int(args_1.model_parallel) == 0:
            self.is_model_parallel = 0
        else:
            self.is_model_parallel = 1

    def construct(self, input_data, label):
        if self.is_model_parallel == 0:
            _, output = self.network(input_data, label)
            loss = self.criterion(output, label)
        else:
            _ = self.network(input_data, label)
            loss = self.criterion(None, label)
        return loss


def parse_args():
    """Parse command-line options; unknown args are ignored (parse_known_args)."""
    parser = argparse.ArgumentParser('MindSpore Face Recognition')
    parser.add_argument('--train_stage', type=str, default='base', help='train stage, base or beta')
    parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
    args_opt_1, _ = parser.parse_known_args()
    return args_opt_1


if __name__ == "__main__":
    args_opt = parse_args()
    support_train_stage = ['base', 'beta']
    if args_opt.train_stage.lower() not in support_train_stage:
        # FIX: the original called args.logger here, but `args` is only assigned
        # below and its logger even later -- that path raised NameError instead
        # of the intended ValueError. Raise directly with the full message.
        raise ValueError('support train stage is:{}, while yours is:{}, train stage not support.'
                         .format(support_train_stage, args_opt.train_stage))

    # Pick the per-stage config object; all hyper-parameters hang off `args`.
    args = config_base if args_opt.train_stage.lower() == 'base' else config_beta
    args.is_distributed = args_opt.is_distributed
    if args_opt.is_distributed:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()
        parallel_mode = ParallelMode.HYBRID_PARALLEL
    else:
        # Single device: args.local_rank / args.world_size come from config defaults.
        parallel_mode = ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=args.world_size, gradients_mean=True)

    if not os.path.exists(args.data_dir):
        # FIX: args.logger is created only further down (get_logger), so the
        # original args.logger.info() call raised AttributeError before the
        # intended ValueError could fire. Raise directly instead.
        raise ValueError('ERROR, data_dir is not exists, please set data_dir in config.py')

    args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
    log_path = os.path.join(args.ckpt_path, 'logs')
    args.logger = get_logger(log_path, args.local_rank)

    # Only one process per 8-device node creates the checkpoint directory.
    if args.local_rank % 8 == 0:
        if not os.path.exists(args.ckpt_path):
            os.makedirs(args.ckpt_path)

    args.logger.info('args.world_size:{}'.format(args.world_size))
    args.logger.info('args.local_rank:{}'.format(args.local_rank))
    args.logger.info('args.lr:{}'.format(args.lr))

    momentum = args.momentum
    weight_decay = args.weight_decay

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(args)
    args.logger.info('de_dataset:{}'.format(de_dataset.get_dataset_size()))
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('loaded, nums: {}'.format(args.num_classes))

    # Optionally round the class count up to a multiple of 16 (per device in
    # model-parallel mode) so the FC weight shapes stay hardware friendly.
    if args.nc_16 == 1:
        if args.model_parallel == 0:
            if args.num_classes % 16 == 0:
                args.logger.info('data parallel aleardy 16, nums: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // 16 + 1) * 16
        else:
            if args.num_classes % (args.world_size * 16) == 0:
                args.logger.info('model parallel aleardy 16, nums: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16

    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))

    args.logger.info('get_backbone----in----')
    _backbone = get_backbone(args)
    args.logger.info('get_backbone----out----')
    args.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(args)
    args.logger.info('get_metric_fc----out----')
    args.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    args.logger.info('DistributedHelper----out----')
    args.logger.info('network fp16----in----')
    if args.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    args.logger.info('network fp16----out----')

    criterion_1 = get_loss(args)
    # Keep the loss in fp32 when the network itself runs fp16 (data parallel only).
    if args.fp16 == 1 and args.model_parallel == 0:
        criterion_1.add_flags_recursive(fp32=True)

    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        # Strip optimizer state ('moments.') and the 'network.' prefix so the
        # checkpoint keys line up with this network's parameter names.
        param_dict_new = {}
        if args_opt.train_stage.lower() == 'base':
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    param_dict_new[key[8:]] = value
        else:
            # Beta stage: additionally drop layers that are re-initialized
            # (first BN of each layer block, SE blocks, head, margin FC weight).
            for key, value in param_dict.items():
                if key.startswith('moments.'):
                    continue
                elif key.startswith('network.'):
                    if 'layers.' in key and 'bn1' in key:
                        continue
                    elif 'se' in key:
                        continue
                    elif 'head' in key:
                        continue
                    elif 'margin_fc.weight' in key:
                        continue
                    else:
                        param_dict_new[key[8:]] = value
        load_param_into_net(network_1, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))
    else:
        init_net(args, network_1)

    train_net = BuildTrainNetwork(network_1, criterion_1, args)
    args.logger.info('args:{}'.format(args))

    # call warmup_step should behind the args steps_per_epoch
    args.lrs = warmup_step_list(args, gamma=0.1)
    lrs_gen = list_to_gen(args.lrs)
    opt = Momentum(params=train_net.trainable_params(), learning_rate=lrs_gen,
                   momentum=momentum, weight_decay=weight_decay)
    scale_manager = DynamicLossScaleManager(init_loss_scale=args.dynamic_init_loss_scale,
                                            scale_factor=2, scale_window=2000)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)

    save_checkpoint_steps = args.ckpt_steps
    args.logger.info('save_checkpoint_steps:{}'.format(save_checkpoint_steps))
    if args.max_ckpts == -1:
        keep_checkpoint_max = int(args.steps_per_epoch * args.max_epoch / save_checkpoint_steps) + 5  # for more than 5
    else:
        keep_checkpoint_max = args.max_ckpts
    args.logger.info('keep_checkpoint_max:{}'.format(keep_checkpoint_max))
    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                   keep_checkpoint_max=keep_checkpoint_max)

    max_epoch_train = args.max_epoch
    args.logger.info('max_epoch_train:{}'.format(max_epoch_train))
    ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=args.ckpt_path,
                              prefix='{}'.format(args.local_rank))
    args.epoch_cnt = 0
    progress_cb = ProgressMonitor(args)

    # Train by sink chunks of log_interval steps; epoch count is rescaled
    # accordingly so total step count stays max_epoch * steps_per_epoch.
    new_epoch_train = max_epoch_train * steps_per_epoch // args.log_interval
    model.train(new_epoch_train, de_dataset,
                callbacks=[progress_cb, ckpt_cb], sink_size=args.log_interval)