|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""Communication management API"""
|
|
|
|
|
import os
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
|
|
|
|
from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
|
|
|
|
_get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
|
|
|
|
@ -45,7 +46,7 @@ class GlobalComm:
|
|
|
|
|
WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init(backend_name="hccl"):
|
|
|
|
|
def init(backend_name=None):
|
|
|
|
|
"""
|
|
|
|
|
Initialize the distributed backend, e.g., hccl/nccl. This is required before the communication service can be used.
|
|
|
|
|
|
|
|
|
@ -57,11 +58,20 @@ def init(backend_name="hccl"):
|
|
|
|
|
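backend_name (str): Backend name, e.g., "hccl" or "nccl". If None, it is chosen from the device target ("hccl" for Ascend, "nccl" for GPU).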
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If backend name is not a string.
|
|
|
|
|
TypeError: If backend_name is not a string.
|
|
|
|
|
RuntimeError: If device target is invalid.
|
|
|
|
|
RuntimeError: If backend is invalid or distributed init fails.
|
|
|
|
|
"""
|
|
|
|
|
if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
|
|
|
|
|
return
|
|
|
|
|
if backend_name is None:
|
|
|
|
|
device_target = context.get_context("device_target")
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
backend_name = "hccl"
|
|
|
|
|
elif device_target == "GPU":
|
|
|
|
|
backend_name = "nccl"
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Device target {} is not supported.".format(device_target))
|
|
|
|
|
if not isinstance(backend_name, str):
|
|
|
|
|
raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
|
|
|
|
|
|
|
|
|
|