diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 2c298df4ec..e609b46593 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -81,11 +81,11 @@ if __name__ == '__main__': init() # GPU target else: + init("nccl") context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) if args_opt.net == "resnet50": auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) - init("nccl") ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" # create dataset