|
|
|
@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
|
|
|
|
|
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
|
|
|
|
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
|
|
|
|
|
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
|
|
|
|
|
_get_local_rank_helper, _get_local_size_helper
|
|
|
|
|
_get_local_rank_helper, _get_local_size_helper, GlobalComm
|
|
|
|
|
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
|
|
|
|
|
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
|
|
|
|
|
|
|
|
|
|
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|
|
DEFAULT_BACKEND = Backend("hccl")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_group(group):
|
|
|
|
|
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
|
|
|
|
@ -38,11 +36,6 @@ def _get_group(group):
|
|
|
|
|
return group
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GlobalComm:
|
|
|
|
|
"""World communication information."""
|
|
|
|
|
BACKEND = DEFAULT_BACKEND
|
|
|
|
|
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init(backend_name=None):
|
|
|
|
|
"""
|
|
|
|
@ -78,10 +71,12 @@ def init(backend_name=None):
|
|
|
|
|
init_hccl()
|
|
|
|
|
GlobalComm.BACKEND = Backend("hccl")
|
|
|
|
|
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|
|
GlobalComm.INITED = True
|
|
|
|
|
elif backend_name == "nccl":
|
|
|
|
|
init_gpu_collective()
|
|
|
|
|
GlobalComm.BACKEND = Backend("nccl")
|
|
|
|
|
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
|
|
|
|
|
GlobalComm.INITED = True
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Backend name {} is not supported.".format(backend_name))
|
|
|
|
|
|
|
|
|
|