From f5763bdebb3ed4caa9337c19594268912f3d839a Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Tue, 23 Mar 2021 17:08:14 +0800
Subject: [PATCH] check whether do init in distributed scene

---
Reviewer notes (text in this region is not part of the commit message):

  * Moves `DEFAULT_BACKEND` and the `GlobalComm` class out of
    management.py and into _comm_helper.py; management.py now imports
    `GlobalComm` from `._comm_helper`.
  * The relocated `GlobalComm.WORLD_COMM_GROUP` is initialised from
    `HCCL_WORLD_COMM_GROUP` directly instead of management.py's
    `DEFAULT_WORLD_COMM_GROUP` alias (which is assigned the same value).
  * Adds `GlobalComm.INITED`, False by default; `init()` sets it to True
    in both the "hccl" and the "nccl" branch.
  * The `wrapper` inside `check_parameter_available` now raises
    RuntimeError when `GlobalComm.INITED` is False. The check sits after
    the existing early return for the pserver/sched roles, so those
    roles are not affected by it.
  * NOTE(review): no hunk in this patch sets `INITED` back to False;
    confirm whether a release/teardown path should reset it.

 mindspore/communication/_comm_helper.py |  9 +++++++++
 mindspore/communication/management.py   | 11 +++--------
 2 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py
index f70616ce0a..46f77f4ed8 100644
--- a/mindspore/communication/_comm_helper.py
+++ b/mindspore/communication/_comm_helper.py
@@ -77,6 +77,13 @@ class Backend:
             raise ValueError("Invalid backend: '{}'".format(name))
         return value
 
+DEFAULT_BACKEND = Backend("hccl")
+
+class GlobalComm:
+    """World communication information."""
+    BACKEND = DEFAULT_BACKEND
+    WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+    INITED = False
 
 def is_hccl_available():
     """
@@ -114,6 +121,8 @@ def check_parameter_available(func):
     def wrapper(*args, **kargs):
         if _is_role_pserver() or _is_role_sched():
             return func(*args, **kargs)
+        if not GlobalComm.INITED:
+            raise RuntimeError("Distributed Communication has not been inited")
         group = None
         if "group" in kargs.keys():
             group = kargs.get("group")
diff --git a/mindspore/communication/management.py b/mindspore/communication/management.py
index 122b4515af..36febb0e46 100755
--- a/mindspore/communication/management.py
+++ b/mindspore/communication/management.py
@@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
     _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
-    _get_local_rank_helper, _get_local_size_helper
+    _get_local_rank_helper, _get_local_size_helper, GlobalComm
 from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
 
 
@@ -28,8 +28,6 @@ __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
            "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
 
 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
-DEFAULT_BACKEND = Backend("hccl")
-
 
 def _get_group(group):
     """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
@@ -38,11 +36,6 @@ def _get_group(group):
     return group
 
 
-class GlobalComm:
-    """World communication information."""
-    BACKEND = DEFAULT_BACKEND
-    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
-
 
 def init(backend_name=None):
     """
@@ -78,10 +71,12 @@ def init(backend_name=None):
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+        GlobalComm.INITED = True
     elif backend_name == "nccl":
         init_gpu_collective()
         GlobalComm.BACKEND = Backend("nccl")
         GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
+        GlobalComm.INITED = True
     else:
         raise RuntimeError("Backend name {} is not supported.".format(backend_name))
 