You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/research/cv/FaceAttribute/train.py

233 lines
8.4 KiB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Face attribute train."""
import os
import time
import datetime
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.communication.management import get_group_size, init, get_rank
from mindspore.nn import TrainOneStepCell
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, RunContext, _InternalCallbackParam, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
from src.FaceAttribute.resnet18 import get_resnet18
from src.FaceAttribute.loss_factory import get_loss
from src.dataset_train import data_generator
from src.lrsche_factory import warmup_step
from src.logging import get_logger, AverageMeter
from src.config import config
# Device index taken from the DEVICE_ID environment variable. Falls back to
# device 0 when it is unset instead of failing with int(None) -> TypeError.
devid = int(os.getenv('DEVICE_ID', '0'))
# Graph mode on an Ascend device, without dumping IR graphs.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=devid)
class BuildTrainNetwork(nn.Cell):
    """Combine a backbone and its loss into one cell that returns the loss.

    Args:
        network: Cell whose forward pass yields exactly three logit outputs.
        criterion: Cell invoked as ``criterion(logit0, logit1, logit2, label)``.
    """

    def __init__(self, network, criterion):
        super().__init__()
        self.network = network
        self.criterion = criterion
        # Print op retained from the original code; construct() never uses it.
        self.print = P.Print()

    def construct(self, input_data, label):
        # Forward the batch through the backbone, then score all three heads.
        logit0, logit1, logit2 = self.network(input_data)
        return self.criterion(logit0, logit1, logit2, label)
def parse_args():
    """Parse the face-attribute training options from the command line.

    Options not declared here are ignored rather than rejected, because
    ``parse_known_args`` is used.

    Returns:
        argparse.Namespace carrying ``mindrecord_path``, ``pretrained``,
        ``local_rank`` and ``world_size``.
    """
    option_specs = (
        ('--mindrecord_path', str, '', 'dataset path, e.g. /home/data.mindrecord'),
        ('--pretrained', str, '', 'pretrained model to load'),
        ('--local_rank', int, 0, 'current rank to support distributed'),
        ('--world_size', int, 8, 'current process number to support distributed'),
    )
    parser = argparse.ArgumentParser('Face Attributes')
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    known, _unknown = parser.parse_known_args()
    return known
def _load_pretrained(network, ckpt_file, logger):
    """Load parameters from the checkpoint ``ckpt_file`` into ``network``.

    Entries whose names start with ``moments.`` are skipped (presumably the
    Momentum optimizer's accumulators, which the bare backbone does not own).
    A leading ``network.`` prefix is stripped, presumably added when the
    backbone is stored as the ``network`` attribute of ``BuildTrainNetwork``,
    so such checkpoints match the backbone's own parameter names.

    Args:
        network: Backbone cell that receives the parameters.
        ckpt_file: Path to an existing checkpoint file.
        logger: Logger used to report a successful load.
    """
    prefix = 'network.'
    param_dict_new = {}
    for key, value in load_checkpoint(ckpt_file).items():
        if key.startswith('moments.'):
            continue
        if key.startswith(prefix):
            key = key[len(prefix):]
        param_dict_new[key] = value
    load_param_into_net(network, param_dict_new)
    logger.info('load model {} success'.format(ckpt_file))


def train():
    """Train the face-attribute ResNet-18 network.

    Command-line options come from ``parse_args``. All other
    hyper-parameters are copied from ``src.config.config``.

    - Runs data-parallel when ``world_size != 1``.
    - Logs through ``get_logger``.
    - On rank 0 only, saves checkpoints to a time-stamped directory under
      ``config.ckpt_path``.
    """
    args = parse_args()

    # init distributed
    if args.world_size != 1:
        init()
        args.local_rank = get_rank()
        args.world_size = get_group_size()

    # Copy the hyper-parameters from the config module onto args.
    for key in ('per_batch_size', 'dst_h', 'dst_w', 'workers', 'attri_num',
                'classes', 'backbone', 'loss_scale', 'flat_dim', 'fc_dim',
                'lr', 'lr_scale', 'lr_epochs', 'weight_decay', 'momentum',
                'max_epoch', 'warmup_epochs', 'log_interval', 'ckpt_path'):
        setattr(args, key, getattr(config, key))

    if args.world_size == 1:
        # Single-device runs override config.per_batch_size with a fixed 256.
        args.per_batch_size = 256
    else:
        # Multi-device runs use a fixed 4x learning rate (not proportional to world_size).
        args.lr = args.lr * 4.

    parallel_mode = ParallelMode.DATA_PARALLEL if args.world_size != 1 else ParallelMode.STAND_ALONE
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=args.world_size)

    # model and log save path
    args.outputs_dir = os.path.join(args.ckpt_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.local_rank)
    loss_meter = AverageMeter('loss')

    # dataloader
    args.logger.info('start create dataloader')
    de_dataloader, steps_per_epoch, num_classes = data_generator(args)
    args.steps_per_epoch = steps_per_epoch
    args.num_classes = num_classes
    args.logger.info('end create dataloader')
    args.logger.save_args(args)

    # backbone and loss
    args.logger.important_info('start create network')
    create_network_start = time.time()
    network = get_resnet18(args)
    criterion = get_loss()

    # load pretrain model
    if os.path.isfile(args.pretrained):
        _load_pretrained(network, args.pretrained, args.logger)
    elif args.pretrained:
        # A path was given but is not a file: report it instead of silently
        # training from scratch.
        args.logger.info('pretrained model {} not found, train from scratch'.format(args.pretrained))

    # optimizer and lr scheduler
    lr = warmup_step(args, gamma=0.1)
    opt = Momentum(params=network.trainable_params(),
                   learning_rate=lr,
                   momentum=args.momentum,
                   weight_decay=args.weight_decay,
                   loss_scale=args.loss_scale)
    train_net = BuildTrainNetwork(network, criterion)
    # mixed precision training: flag the loss cell (recursively) with fp32=True,
    # presumably so it stays in float32.
    criterion.add_flags_recursive(fp32=True)
    # package training process
    train_net = TrainOneStepCell(train_net, opt, sens=args.loss_scale)
    context.reset_auto_parallel_context()

    is_main_rank = args.local_rank == 0

    # checkpoint callback, driven manually (rank 0 only)
    if is_main_rank:
        ckpt_max_num = args.max_epoch
        train_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch, keep_checkpoint_max=ckpt_max_num)
        ckpt_cb = ModelCheckpoint(config=train_config, directory=args.outputs_dir, prefix='{}'.format(args.local_rank))
        cb_params = _InternalCallbackParam()
        cb_params.train_network = train_net
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 0
        run_context = RunContext(cb_params)
        ckpt_cb.begin(run_context)

    train_net.set_train()
    t_end = time.time()
    t_epoch = time.time()
    old_progress = -1
    for i, (data, gt_classes) in enumerate(de_dataloader):
        data_tensor = Tensor(data, dtype=mstype.float32)
        gt_tensor = Tensor(gt_classes, dtype=mstype.int32)
        loss = train_net(data_tensor, gt_tensor)
        loss_meter.update(loss.asnumpy()[0])
        epoch_start = i % args.steps_per_epoch == 0

        # save ckpt
        if is_main_rank:
            cb_params.cur_step_num = i + 1
            # NOTE(review): batch_num is kept one past cur_step_num, presumably so
            # ModelCheckpoint's in-epoch step number equals the global step; verify.
            cb_params.batch_num = i + 2
            ckpt_cb.step_end(run_context)
            if epoch_start:
                cb_params.cur_epoch_num += 1

        # save Log
        if i == 0:
            time_for_graph_compile = time.time() - create_network_start
            args.logger.important_info('{}, graph compile time={:.2f}s'.format(args.backbone, time_for_graph_compile))

        if i % args.log_interval == 0 and is_main_rank:
            time_used = time.time() - t_end
            epoch = i // args.steps_per_epoch
            # Throughput over the steps run since the previous log line, summed over all devices.
            fps = args.per_batch_size * (i - old_progress) * args.world_size / time_used
            args.logger.info('epoch[{}], iter[{}], {}, {:.2f} imgs/sec'.format(epoch, i, loss_meter, fps))
            t_end = time.time()
            loss_meter.reset()
            old_progress = i

        if epoch_start and is_main_rank:
            # At i == 0 only the first (graph-compiling) step has run. Scaling it
            # by steps_per_epoch would report a meaningless rate, so only the
            # epoch timer is restarted.
            if i > 0:
                epoch_time_used = time.time() - t_epoch
                epoch = i // args.steps_per_epoch
                fps = args.per_batch_size * args.world_size * args.steps_per_epoch / epoch_time_used
                args.logger.info('=================================================')
                args.logger.info('epoch time: epoch[{}], iter[{}], {:.2f} imgs/sec'.format(epoch, i, fps))
                args.logger.info('=================================================')
            t_epoch = time.time()

    args.logger.info('--------- trains out ---------')
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    train()