diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 5f7eca4fc4..76239047ba 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -124,15 +124,6 @@ def parse_args(): args.data_root = os.path.join(args.data_dir, 'train2014') args.annFile = os.path.join(args.data_dir, 'annotations/instances_train2014.json') - # init distributed - if args.is_distributed: - if args.device_target == "Ascend": - init() - else: - init("nccl") - args.rank = get_rank() - args.group_size = get_group_size() - # select for master rank save ckpt or all rank save, compatiable for model parallel args.rank_save_ckpt_flag = 0 if args.is_save_on_master: @@ -161,6 +152,14 @@ def train(): devid = int(os.getenv('DEVICE_ID', '0')) context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, device_target=args.device_target, save_graphs=True, device_id=devid) + # init distributed + if args.is_distributed: + if args.device_target == "Ascend": + init() + else: + init("nccl") + args.rank = get_rank() + args.group_size = get_group_size() if args.need_profiler: from mindspore.profiler.profiling import Profiler profiler = Profiler(output_path=args.outputs_dir, is_detail=True, is_show_op_path=True)