|
|
|
@@ -191,12 +191,13 @@ if __name__ == '__main__':
|
|
|
|
|
# Distributed-training bootstrap (script-level code; per the hunk header this
# runs under `if __name__ == '__main__':`).
# Initializes the collective-communication backend for the selected device
# target, then configures data-parallel auto-parallel across the group.
if args.is_distributed:
    if args.device_target == "Ascend":
        # Ascend backend: default init(); each process pins its own device.
        init()
        context.set_context(device_id=args.device_id)
    elif args.device_target == "GPU":
        # GPU backend uses NCCL for collectives; device binding is handled
        # by NCCL/the launcher, so no explicit device_id is set here.
        init("nccl")
    # Rank and group size apply to both backends, so query them once after
    # init() — the original redundantly repeated these three assignments
    # inside the GPU branch and then again unconditionally.
    args.rank = get_rank()
    args.group_size = get_group_size()
    device_num = args.group_size
    # Reset any stale parallel context before configuring data parallelism.
    # NOTE(review): `mirror_mean=True` averages gradients across devices —
    # assumes the surrounding file targets an older MindSpore API (newer
    # versions renamed it to `gradients_mean`); confirm against the imports.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)