|
|
|
@ -71,7 +71,6 @@ if __name__ == '__main__':
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
device_num = int(os.environ.get("DEVICE_NUM", 1))
|
|
|
|
if device_num > 1:
|
|
|
|
if device_num > 1:
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.set_context(enable_hccl=True)
|
|
|
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
mirror_mean=True)
|
|
|
|
mirror_mean=True)
|
|
|
|
init()
|
|
|
|
init()
|
|
|
|
|