diff --git a/model_zoo/official/cv/FCN8s/train.py b/model_zoo/official/cv/FCN8s/train.py index f076ff7589..4315be3856 100644 --- a/model_zoo/official/cv/FCN8s/train.py +++ b/model_zoo/official/cv/FCN8s/train.py @@ -49,12 +49,14 @@ def train(): context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False, device_target="Ascend", device_id=args.device_id) # init multicards training + args.rank = 0 + args.group_size = 1 if device_num > 1: parallel_mode = ParallelMode.DATA_PARALLEL context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num) init() - args.rank = get_rank() - args.group_size = get_group_size() + args.rank = get_rank() + args.group_size = get_group_size() # dataset dataset = data_generator.SegDataset(image_mean=cfg.image_mean,