diff --git a/model_zoo/official/cv/retinanet/README_CN.md b/model_zoo/official/cv/retinanet/README_CN.md index 989e8c6317..8de62c40d6 100644 --- a/model_zoo/official/cv/retinanet/README_CN.md +++ b/model_zoo/official/cv/retinanet/README_CN.md @@ -165,11 +165,11 @@ MSCOCO2017 # 八卡并行训练示例: 创建 RANK_TABLE_FILE -sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) +sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR RANK_TABLE_FILE PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) # 单卡训练示例: -sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) +sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR PRE_TRAINED(optional) PRE_TRAINED_EPOCH_SIZE(optional) ``` @@ -182,6 +182,9 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) ```运行 # 训练示例 + 训练前,先创建MindRecord文件,以COCO数据集为例 + python create_data.py --dataset coco + python: data和存储mindrecord文件的路径在config里设置 @@ -193,12 +196,12 @@ sh run_distribute_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED(optional) # 八卡并行训练示例(在retinanet目录下运行): - sh scripts/run_distribute_train.sh 8 500 0.1 coco RANK_TABLE_FILE(创建的RANK_TABLE_FILE的地址) PRE_TRAINED(预训练checkpoint地址) PRE_TRAINED_EPOCH_SIZE(预训练EPOCH大小) - 例如:sh scripts/run_distribute_train.sh 8 500 0.1 coco scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322 + sh scripts/run_distribute_train.sh 8 500 0.1 RANK_TABLE_FILE(创建的RANK_TABLE_FILE的地址) PRE_TRAINED(预训练checkpoint地址) PRE_TRAINED_EPOCH_SIZE(预训练EPOCH大小) + 例如:sh scripts/run_distribute_train.sh 8 500 0.1 scripts/rank_table_8pcs.json /dataset/retinanet-322_458.ckpt 322 # 单卡训练示例(在retinanet目录下运行): - sh scripts/run_single_train.sh 0 500 0.1 coco /dataset/retinanet-322_458.ckpt 322 + sh scripts/run_single_train.sh 0 500 0.1 /dataset/retinanet-322_458.ckpt 322 ``` #### 结果 diff --git a/model_zoo/official/cv/retinanet/create_data.py 
b/model_zoo/official/cv/retinanet/create_data.py
new file mode 100644
index 0000000000..9c301aecdd
--- /dev/null
+++ b/model_zoo/official/cv/retinanet/create_data.py
@@ -0,0 +1,25 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Command-line entry point that builds the MindRecord files used to train retinanet."""
+
+import argparse
+from src.dataset import create_mindrecord
+
+if __name__ == "__main__":
+    cli = argparse.ArgumentParser(description="retinanet dataset create")
+    cli.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
+    dataset_name = cli.parse_args().dataset  # "coco" unless --dataset is given
+    mindrecord_file = create_mindrecord(dataset_name, "retinanet.mindrecord", True)
diff --git a/model_zoo/official/cv/retinanet/export.py b/model_zoo/official/cv/retinanet/export.py
new file mode 100644
index 0000000000..8271acad08
--- /dev/null
+++ b/model_zoo/official/cv/retinanet/export.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export the retinanet network wrapped by retinanetInferWithDecoder to an AIR or MINDIR file."""
+import argparse
+import numpy as np
+import mindspore.common.dtype as mstype
+from mindspore import context, Tensor
+from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
+from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
+from src.config import config
+from src.box_utils import default_boxes
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='retinanet export')
+    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
+    parser.add_argument("--run_platform", type=str, default="Ascend", choices=["Ascend"],  # list, not a bare str
+                        help="run platform, only support Ascend.")
+    parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+    parser.add_argument("--file_name", type=str, default="retinanet", help="output file name.")
+    args_opt = parser.parse_args()
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
+
+    backbone = resnet50(config.num_classes)
+    net = retinanet50(backbone, config)
+    net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
+    param_dict = load_checkpoint(config.checkpoint_path)  # checkpoint path is read from src/config.py, not the CLI
+    net.init_parameters_data()
+    load_param_into_net(net, param_dict)
+    net.set_train(False)
+    shape = [args_opt.batch_size, 3] + config.img_shape  # assumes config.img_shape is a list [H, W] - confirm
+    input_data = Tensor(np.zeros(shape), mstype.float32)  # all-zero dummy batch handed to export()
+    export(net, input_data, file_name=args_opt.file_name, file_format=args_opt.file_format)
diff --git a/model_zoo/official/cv/retinanet/scripts/run_distribute_train.sh b/model_zoo/official/cv/retinanet/scripts/run_distribute_train.sh index ebb3a80dee..efd860b697 100644 --- a/model_zoo/official/cv/retinanet/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/retinanet/scripts/run_distribute_train.sh @@ -21,27 +21,24 @@ echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /op echo "It is better to use absolute path." echo "=================================================================================================================" -if [ $# != 5 ] && [ $# != 7 ] +if [ $# != 4 ] && [ $# != 6 ] then - echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \ + echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] \ [RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" exit 1 fi -# Before start distribute train, first create mindrecord files. -BASE_PATH=$(cd "`dirname $0`" || exit; pwd) -cd $BASE_PATH/../ || exit -python train.py --only_create_dataset=True +core_num=`cat /proc/cpuinfo |grep "processor"|wc -l` +process_cores=$(($core_num/8)) echo "After running the script, the network runs in the background. 
The log will be generated in LOGx/log.txt" export RANK_SIZE=$1 EPOCH_SIZE=$2 LR=$3 -DATASET=$4 -PRE_TRAINED=$6 -PRE_TRAINED_EPOCH_SIZE=$7 -export RANK_TABLE_FILE=$5 +PRE_TRAINED=$5 +PRE_TRAINED_EPOCH_SIZE=$6 +export RANK_TABLE_FILE=$4 for((i=0;i env.log - if [ $# == 5 ] + if [ $# == 4 ] then - python train.py \ + taskset -c $cmdopt python train.py \ + --workers=$process_cores \ --distribute=True \ --lr=$LR \ - --dataset=$DATASET \ --device_num=$RANK_SIZE \ --device_id=$DEVICE_ID \ --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & fi - if [ $# == 7 ] + if [ $# == 6 ] then - python train.py \ + taskset -c $cmdopt python train.py \ + --workers=$process_cores \ --distribute=True \ --lr=$LR \ - --dataset=$DATASET \ --device_num=$RANK_SIZE \ --device_id=$DEVICE_ID \ --pre_trained=$PRE_TRAINED \ diff --git a/model_zoo/official/cv/retinanet/scripts/run_single_train.sh b/model_zoo/official/cv/retinanet/scripts/run_single_train.sh index 9b96055386..53f998df05 100644 --- a/model_zoo/official/cv/retinanet/scripts/run_single_train.sh +++ b/model_zoo/official/cv/retinanet/scripts/run_single_train.sh @@ -8,6 +8,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. @@ -20,9 +21,9 @@ echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458. echo "It is better to use absolute path." 
echo "=================================================================================================================" -if [ $# != 4 ] && [ $# != 6 ] +if [ $# != 3 ] && [ $# != 5 ] then - echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \ + echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] \ [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)" exit 1 fi @@ -37,9 +38,8 @@ echo "After running the script, the network runs in the background. The log will export DEVICE_ID=$1 EPOCH_SIZE=$2 LR=$3 -DATASET=$4 -PRE_TRAINED=$5 -PRE_TRAINED_EPOCH_SIZE=$6 +PRE_TRAINED=$4 +PRE_TRAINED_EPOCH_SIZE=$5 rm -rf LOG$1 mkdir ./LOG$1 @@ -48,23 +48,21 @@ cp -r ./src ./LOG$1 cd ./LOG$1 || exit echo "start training for device $1" env > env.log -if [ $# == 4 ] +if [ $# == 3 ] then python train.py \ --distribute=False \ --lr=$LR \ - --dataset=$DATASET \ --device_num=1 \ --device_id=$DEVICE_ID \ --epoch_size=$EPOCH_SIZE > log.txt 2>&1 & fi -if [ $# == 6 ] +if [ $# == 5 ] then - python train,py \ + python train.py \ --distribute=False \ --lr=$LR \ - --dataset=$DATASET \ --device_num=1 \ --device_id=$DEVICE_ID \ --pre_trained=$PRE_TRAINED \ diff --git a/model_zoo/official/cv/retinanet/src/dataset.py b/model_zoo/official/cv/retinanet/src/dataset.py index f27b2b41c1..150a6857f5 100644 --- a/model_zoo/official/cv/retinanet/src/dataset.py +++ b/model_zoo/official/cv/retinanet/src/dataset.py @@ -389,7 +389,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="reti def create_retinanet_dataset(mindrecord_file, batch_size, repeat_num, device_num=1, rank=0, - is_training=True, num_parallel_workers=64): + is_training=True, num_parallel_workers=24): """Creatr retinanet dataset with MindDataset.""" ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num, shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training) diff --git a/model_zoo/official/cv/retinanet/src/retinanet.py 
b/model_zoo/official/cv/retinanet/src/retinanet.py index b0102a8bff..6e9c4f312b 100644 --- a/model_zoo/official/cv/retinanet/src/retinanet.py +++ b/model_zoo/official/cv/retinanet/src/retinanet.py @@ -251,9 +251,15 @@ class retinanetWithLossCell(nn.Cell): self.expand_dims = P.ExpandDims() self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha) self.loc_loss = nn.SmoothL1Loss() + self.cast = P.Cast() + + self.network.to_float(mstype.float16) def construct(self, x, gt_loc, gt_label, num_matched_boxes): pred_loc, pred_label = self.network(x) + pred_loc = self.cast(pred_loc, mstype.float32) + pred_label = self.cast(pred_label, mstype.float32) + mask = F.cast(self.less(0, gt_label), mstype.float32) num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32)) diff --git a/model_zoo/official/cv/retinanet/train.py b/model_zoo/official/cv/retinanet/train.py index 33a8828ab8..ec463a4c8f 100644 --- a/model_zoo/official/cv/retinanet/train.py +++ b/model_zoo/official/cv/retinanet/train.py @@ -18,7 +18,6 @@ import os import argparse import ast -import mindspore import mindspore.nn as nn from mindspore import context, Tensor from mindspore.communication.management import init, get_rank @@ -29,7 +28,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.common import set_seed from src.retinanet import retinanetWithLossCell, TrainingWrapper, retinanet50, resnet50 from src.config import config -from src.dataset import create_retinanet_dataset, create_mindrecord +from src.dataset import create_retinanet_dataset from src.lr_schedule import get_lr from src.init_params import init_net_param, filter_checkpoint_parameter @@ -59,15 +58,14 @@ class Monitor(Callback): def main(): parser = argparse.ArgumentParser(description="retinanet training") - parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False, - help="If set it true, only create Mindrecord, default is False.") + 
parser.add_argument("--distribute", type=ast.literal_eval, default=False, help="Run distribute, default is False.") + parser.add_argument("--workers", type=int, default=24, help="Num parallel workers.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") - parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") @@ -98,56 +96,55 @@ def main(): else: raise ValueError("Unsupported platform.") - mindrecord_file = create_mindrecord(args_opt.dataset, "retinanet.mindrecord", True) - - if not args_opt.only_create_dataset: - loss_scale = float(args_opt.loss_scale) - - # When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0. 
- dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1, - batch_size=args_opt.batch_size, device_num=device_num, rank=rank) - - dataset_size = dataset.get_dataset_size() - print("Create dataset done!") - - - backbone = resnet50(config.num_classes) - retinanet = retinanet50(backbone, config) - net = retinanetWithLossCell(retinanet, config) - net.to_float(mindspore.float16) - init_net_param(net) - - if args_opt.pre_trained: - if args_opt.pre_trained_epoch_size <= 0: - raise KeyError("pre_trained_epoch_size must be greater than 0.") - param_dict = load_checkpoint(args_opt.pre_trained) - if args_opt.filter_weight: - filter_checkpoint_parameter(param_dict) - load_param_into_net(net, param_dict) - - lr = Tensor(get_lr(global_step=config.global_step, - lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, - warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2, - warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4, - warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size, - steps_per_epoch=dataset_size)) - opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, - config.momentum, config.weight_decay, loss_scale) - net = TrainingWrapper(net, opt, loss_scale) - model = Model(net) - print("Start train retinanet, the first epoch will be slower because of the graph compilation.") - cb = [TimeMonitor(), LossMonitor()] - cb += [Monitor(lr_init=lr.asnumpy())] - config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs, - keep_checkpoint_max=config.keep_checkpoint_max) - ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck) - if args_opt.distribute: - if rank == 0: - cb += [ckpt_cb] - model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) - else: + mindrecord_file = os.path.join(config.mindrecord_dir, "retinanet.mindrecord0") + + loss_scale = 
float(args_opt.loss_scale) + + # When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0. + dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1, + num_parallel_workers=args_opt.workers, + batch_size=args_opt.batch_size, device_num=device_num, rank=rank) + + dataset_size = dataset.get_dataset_size() + print("Create dataset done!") + + + backbone = resnet50(config.num_classes) + retinanet = retinanet50(backbone, config) + net = retinanetWithLossCell(retinanet, config) + init_net_param(net) + + if args_opt.pre_trained: + if args_opt.pre_trained_epoch_size <= 0: + raise KeyError("pre_trained_epoch_size must be greater than 0.") + param_dict = load_checkpoint(args_opt.pre_trained) + if args_opt.filter_weight: + filter_checkpoint_parameter(param_dict) + load_param_into_net(net, param_dict) + + lr = Tensor(get_lr(global_step=config.global_step, + lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr, + warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2, + warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4, + warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size, + steps_per_epoch=dataset_size)) + opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, + config.momentum, config.weight_decay, loss_scale) + net = TrainingWrapper(net, opt, loss_scale) + model = Model(net) + print("Start train retinanet, the first epoch will be slower because of the graph compilation.") + cb = [TimeMonitor(), LossMonitor()] + cb += [Monitor(lr_init=lr.asnumpy())] + config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs, + keep_checkpoint_max=config.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck) + if args_opt.distribute: + if rank == 0: cb += [ckpt_cb] - model.train(args_opt.epoch_size, dataset, callbacks=cb, 
dataset_sink_mode=True) + model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) + else: + cb += [ckpt_cb] + model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True) if __name__ == '__main__': main()