|
|
|
@ -121,7 +121,7 @@ def test(cloud_args=None):
|
|
|
|
|
args = parse_args(cloud_args)
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
|
|
|
|
|
device_target=args.device_target, save_graphs=False)
|
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit() and args.device_target == "Ascend":
|
|
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
|
|
|
|
|
args.outputs_dir = os.path.join(args.log_path,
|
|
|
|
|