remove device_id setting when use GPU

pull/7078/head
caojian05 4 years ago
parent 13bf28e538
commit bfe85189d2

@ -121,7 +121,7 @@ def test(cloud_args=None):
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.device_target, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit():
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
args.outputs_dir = os.path.join(args.log_path,

@ -141,6 +141,7 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

Loading…
Cancel
Save