diff --git a/model_zoo/official/cv/yolov3_darknet53/README.md b/model_zoo/official/cv/yolov3_darknet53/README.md index bc2ba7ec3c..9181e11607 100644 --- a/model_zoo/official/cv/yolov3_darknet53/README.md +++ b/model_zoo/official/cv/yolov3_darknet53/README.md @@ -75,7 +75,6 @@ python train.py \ --is_distributed=0 \ --lr=0.001 \ --loss_scale=1024 \ - --sens=1024 \ --weight_decay=0.016 \ --T_max=320 \ --max_epoch=320 \ @@ -175,8 +174,6 @@ optional arguments: Whether to use label smooth in CE. Default:0 --label_smooth_factor LABEL_SMOOTH_FACTOR Smooth strength of original one-hot. Default: 0.1 - --sens SENS - Static sens. Default: 1024 --log_interval LOG_INTERVAL Logging interval steps. Default: 100 --ckpt_path CKPT_PATH @@ -211,7 +208,6 @@ python train.py \ --is_distributed=0 \ --lr=0.001 \ --loss_scale=1024 \ - --sens=1024 \ --weight_decay=0.016 \ --T_max=320 \ --max_epoch=320 \ diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 76239047ba..679f84da30 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -124,20 +124,6 @@ def parse_args(): args.data_root = os.path.join(args.data_dir, 'train2014') args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') - # select for master rank save ckpt or all rank save, compatiable for model parallel - args.rank_save_ckpt_flag = 0 - if args.is_save_on_master: - if args.rank == 0: - args.rank_save_ckpt_flag = 1 - else: - args.rank_save_ckpt_flag = 1 - - # logger - args.outputs_dir = os.path.join(args.ckpt_path, - datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - args.logger = get_logger(args.outputs_dir, args.rank) - args.logger.save_args(args) - return args @@ -160,6 +146,20 @@ def train(): init("nccl") args.rank = get_rank() args.group_size = get_group_size() + # select for master rank save ckpt or all rank save, compatiable for model parallel + args.rank_save_ckpt_flag = 0 + if args.is_save_on_master: + if args.rank == 0: + args.rank_save_ckpt_flag = 1 + else: + args.rank_save_ckpt_flag = 1 + + # logger + args.outputs_dir = os.path.join(args.ckpt_path, + datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) + args.logger = get_logger(args.outputs_dir, args.rank) + args.logger.save_args(args) + if args.need_profiler: from mindspore.profiler.profiling import Profiler profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)