Check whether init has been called in distributed scenarios

pull/13857/head
yao_yf 4 years ago
parent c4e339674b
commit f5763bdebb

@ -77,6 +77,13 @@ class Backend:
raise ValueError("Invalid backend: '{}'".format(name)) raise ValueError("Invalid backend: '{}'".format(name))
return value return value
DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
    """Shared record of the world communication settings.

    The attributes start at their defaults and are overwritten by ``init``
    once a backend has been brought up.
    """

    # Backend currently selected; starts as the module-level default.
    BACKEND = DEFAULT_BACKEND

    # World communication group name; starts as the HCCL world group.
    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP

    # Stays False until init() has set up a backend. The parameter-checking
    # wrapper raises RuntimeError while this is False (parameter-server and
    # scheduler roles return before that check).
    INITED = False
def is_hccl_available(): def is_hccl_available():
""" """
@ -114,6 +121,8 @@ def check_parameter_available(func):
def wrapper(*args, **kargs): def wrapper(*args, **kargs):
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs) return func(*args, **kargs)
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
group = None group = None
if "group" in kargs.keys(): if "group" in kargs.keys():
group = kargs.get("group") group = kargs.get("group")

@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \ _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper _get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
DEFAULT_BACKEND = Backend("hccl")
def _get_group(group): def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
@ -38,11 +36,6 @@ def _get_group(group):
return group return group
class GlobalComm:
    """Shared record of the world communication settings.

    Both attributes start at the module defaults and are reassigned by
    ``init`` according to the backend it sets up.
    """

    # Backend currently selected; starts as the module-level default.
    BACKEND = DEFAULT_BACKEND

    # World communication group name; starts as the default world group.
    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
def init(backend_name=None): def init(backend_name=None):
""" """
@ -78,10 +71,12 @@ def init(backend_name=None):
init_hccl() init_hccl()
GlobalComm.BACKEND = Backend("hccl") GlobalComm.BACKEND = Backend("hccl")
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl": elif backend_name == "nccl":
init_gpu_collective() init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl") GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else: else:
raise RuntimeError("Backend name {} is not supported.".format(backend_name)) raise RuntimeError("Backend name {} is not supported.".format(backend_name))

Loading…
Cancel
Save