# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute train."""
import os
import time
import datetime
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn import TrainOneStepCell
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

from src.FaceAttribute.resnet18 import get_resnet18
from src.FaceAttribute.loss_factory import get_loss
from src.dataset_train import data_generator
from src.lrsche_factory import warmup_step
from src.logging import get_logger, AverageMeter
from src.config import config

# DEVICE_ID is normally exported by the Ascend launch scripts; fall back to 0
# so a standalone run does not crash on the missing environment variable.
devid = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)


class BuildTrainNetwork(nn.Cell):
    '''Build train network.'''
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion
        self.print = P.Print()

    def construct(self, input_data, label):
        logit0, logit1, logit2 = self.network(input_data)
        loss = self.criterion(logit0, logit1, logit2, label)
        return loss


def parse_args():
    '''Argument for Face Attributes.'''
    parser = argparse.ArgumentParser('Face Attributes')
    parser.add_argument('--mindrecord_path', type=str, default='', help='dataset path, e.g. /home/data.mindrecord')
    parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load')
    parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
    parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
    args, _ = parser.parse_known_args()
    return args


def train():
    '''train function.'''
    # logger
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    # copy training hyper-parameters from the static config
    args.per_batch_size = config.per_batch_size
    args.dst_h = config.dst_h
    args.dst_w = config.dst_w
    args.workers = config.workers
    args.attri_num = config.attri_num
    args.classes = config.classes
    args.backbone = config.backbone
    args.loss_scale = config.loss_scale
    args.flat_dim = config.flat_dim
    args.fc_dim = config.fc_dim
    args.lr = config.lr
    args.lr_scale = config.lr_scale
    args.lr_epochs = config.lr_epochs
    args.weight_decay = config.weight_decay
    args.momentum = config.momentum
    args.max_epoch = config.max_epoch
    args.warmup_epochs = config.warmup_epochs
    args.log_interval = config.log_interval
    args.ckpt_path = config.ckpt_path

    # recipe heuristic: a single device trains with a larger batch, while
    # distributed runs keep the configured batch and scale up the base LR
    if args.world_size == 1:
        args.per_batch_size = 256
    else:
        args.lr = args.lr * 4.
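
    # note (added for clarity): the auto-parallel context below has to be in
    # place before the training graph is compiled, so it is configured ahead
    # of network and dataloader creation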
    if args.world_size != 1:
        parallel_mode = ParallelMode.DATA_PARALLEL
    else:
        parallel_mode = ParallelMode.STAND_ALONE
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)

    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)
    criterion = get_loss()

    # load pretrain model
    if os.path.isfile(args.pretrained):
        param_dict = load_checkpoint(args.pretrained)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer state saved in the checkpoint
                continue
            elif key.startswith('network.'):
                # drop the 'network.' prefix added by the training wrapper
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        args.logger.info('load model {} success'.format(args.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)

    train_net = BuildTrainNetwork(network, criterion)

    # mixed precision training: keep the loss computation in fp32
    criterion.add_flags_recursive(fp32=True)

    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    # checkpoint
    if args.local_rank == 0:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir,
                                  prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1

    i = 0
    for _, (data, gt_classes) in enumerate(de_dataloader):
        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)

        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])

        # save ckpt
        if args.local_rank == 0:
            cb_params.cur_step_num = i + 1
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)
        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            cb_params.cur_epoch_num += 1

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone,
                                                                               time_for_graph_compile))

        if i % args.log_interval == 0 and args.local_rank == 0:
            time_used = time.time() - t_end
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if i % args.steps_per_epoch == 0 and args.local_rank == 0:
            epoch_time_used = time.time() - t_epoch
            epoch = int(i / args.steps_per_epoch)
            fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
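            # log epoch-level throughput: images processed across all ranks in
            # one epoch divided by that epoch's wall-clock time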
            args.logger.info('=================================================')
            args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
            args.logger.info('=================================================')
            t_epoch = time.time()

        i += 1

    args.logger.info('--------- trains out ---------')


if __name__ == "__main__":
    train()
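

# Usage sketch (assumed invocation; the exact launch tooling and paths are not
# part of this script): a single-device run could look like
#   DEVICE_ID=0 python train.py --mindrecord_path /path/to/data.mindrecord --world_size 1
# while distributed training launches one such process per Ascend device, each
# with its own DEVICE_ID, so init()/get_rank()/get_group_size() can discover
# the communication group.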