diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index 67b45e1344..016d02d6ac 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -76,8 +76,8 @@ if __name__ == "__main__": gradients_mean=True) init() elif device_target == "GPU": - init() if device_num > 1: + init() context.reset_auto_parallel_context() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)