diff --git a/model_zoo/official/cv/vgg16/eval.py b/model_zoo/official/cv/vgg16/eval.py index 324ec047bc..6f228a6a03 100644 --- a/model_zoo/official/cv/vgg16/eval.py +++ b/model_zoo/official/cv/vgg16/eval.py @@ -121,7 +121,7 @@ def test(cloud_args=None): args = parse_args(cloud_args) context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, device_target=args.device_target, save_graphs=False) - if os.getenv('DEVICE_ID', "not_set").isdigit(): + if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend": context.set_context(device_id=int(os.getenv('DEVICE_ID'))) args.outputs_dir = os.path.join(args.log_path, diff --git a/model_zoo/official/cv/vgg16/train.py b/model_zoo/official/cv/vgg16/train.py index 5b889540cc..29bc344e57 100644 --- a/model_zoo/official/cv/vgg16/train.py +++ b/model_zoo/official/cv/vgg16/train.py @@ -141,7 +141,8 @@ if __name__ == '__main__': context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) else: - context.set_context(device_id=args.device_id) + if args.device_target == "Ascend": + context.set_context(device_id=args.device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) # select for master rank save ckpt or all rank save, compatible for model parallel