!13857 check whether communication unit has been inited

From: @yao_yf
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/13857/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b1c86b6a22

@ -77,6 +77,13 @@ class Backend:
raise ValueError("Invalid backend: '{}'".format(name))
return value
DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""World communication information."""
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
def is_hccl_available():
"""
@ -114,6 +121,8 @@ def check_parameter_available(func):
def wrapper(*args, **kargs):
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
group = None
if "group" in kargs.keys():
group = kargs.get("group")

@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper
_get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
@ -38,11 +36,6 @@ def _get_group(group):
return group
class GlobalComm:
"""World communication information."""
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
def init(backend_name=None):
"""
@ -78,10 +71,12 @@ def init(backend_name=None):
init_hccl()
GlobalComm.BACKEND = Backend("hccl")
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name))

Loading…
Cancel
Save