|
|
@ -62,8 +62,8 @@ def init(backend_name=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
return
|
|
|
|
return
|
|
|
|
if backend_name is None:
|
|
|
|
|
|
|
|
device_target = context.get_context("device_target")
|
|
|
|
device_target = context.get_context("device_target")
|
|
|
|
|
|
|
|
if backend_name is None:
|
|
|
|
if device_target == "Ascend":
|
|
|
|
if device_target == "Ascend":
|
|
|
|
backend_name = "hccl"
|
|
|
|
backend_name = "hccl"
|
|
|
|
elif device_target == "GPU":
|
|
|
|
elif device_target == "GPU":
|
|
|
@ -74,6 +74,8 @@ def init(backend_name=None):
|
|
|
|
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
|
|
|
|
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
|
|
|
|
|
|
|
|
|
|
|
|
if backend_name == "hccl":
|
|
|
|
if backend_name == "hccl":
|
|
|
|
|
|
|
|
if device_target != "Ascend":
|
|
|
|
|
|
|
|
raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
|
|
|
|
init_hccl()
|
|
|
|
init_hccl()
|
|
|
|
GlobalComm.BACKEND = Backend("hccl")
|
|
|
|
GlobalComm.BACKEND = Backend("hccl")
|
|
|
|
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|