|
|
|
@ -32,33 +32,32 @@ DEFAULT_BACKEND = Backend("hccl")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_group(group):
    """Return the world communication group if `group` is `DEFAULT_WORLD_COMM_GROUP`.

    Any other value of `group` is handed back unchanged.
    """
    # Guard clause: only the default placeholder is mapped to the world group.
    if group != DEFAULT_WORLD_COMM_GROUP:
        return group
    return GlobalComm.WORLD_COMM_GROUP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GlobalComm:
    """World communication information.

    Class-level holder for the backend and the world communication group
    that the communication helpers in this module read.
    """

    # Communication backend; starts out as the module-level default backend.
    BACKEND = DEFAULT_BACKEND
    # World communication group; starts out as the module-level default group.
    WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init(backend_name=None):
|
|
|
|
|
"""
|
|
|
|
|
Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
|
|
|
|
|
Initialize the distributed backend, e.g. HCCL/NCCL. It is required before using the communication service.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The full name of hccl is Huawei Collective Communication Library.
|
|
|
|
|
The full name of nccl is NVIDIA Collective Communication Library.
|
|
|
|
|
The full name of HCCL is Huawei Collective Communication Library.
|
|
|
|
|
The full name of NCCL is NVIDIA Collective Communication Library.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
backend_name (str): Backend.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If backen_name is not a string.
|
|
|
|
|
RuntimeError: If device target is invalid.
|
|
|
|
|
RuntimeError: If backend is invalid or distributed init fails.
|
|
|
|
|
TypeError: If `backend_name` is not a string.
|
|
|
|
|
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails.
|
|
|
|
|
"""
|
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
|
return
|
|
|
|
@ -88,17 +87,17 @@ def init(backend_name=None):
|
|
|
|
|
|
|
|
|
|
def release():
    """
    Release distributed resources, e.g. HCCL/NCCL.

    Delegates to `finalize_hccl()`.

    Raises:
        RuntimeError: If failed to release distributed resources.
    """
    finalize_hccl()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
"""
|
|
|
|
|
Gets rank ID for current device in specified collective communication group.
|
|
|
|
|
Get the rank ID for the current device in the specified collective communication group.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
group (str): ProcessGroup, the process group to work on. Default: WORLD_COMM_GROUP.
|
|
|
|
@ -109,7 +108,7 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
ValueError: If backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
@ -130,14 +129,14 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
ValueError: If backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
"""
|
|
|
|
|
Gets rank size of the specified collective communication group.
|
|
|
|
|
Get the rank size of the specified collective communication group.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
group (str): ProcessGroup, the process group to work on.
|
|
|
|
@ -148,7 +147,7 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
ValueError: If backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
@ -164,22 +163,23 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
group (str): ProcessGroup, the process group to work on.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int, the local rank size where the calling process is being within the group.
|
|
|
|
|
int, the local rank size where the calling process is within the group.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
ValueError: If backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_world_rank_from_group_rank(group, group_rank_id):
|
|
|
|
|
"""
|
|
|
|
|
Gets the rank ID in world communication group corresponding to the rank ID in specified user communication group.
|
|
|
|
|
Get the rank ID in the world communication group corresponding to
|
|
|
|
|
the rank ID in the specified user communication group.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Nccl is not supported.
|
|
|
|
|
NCCL is not supported.
|
|
|
|
|
The parameter group should not be "hccl_world_group".
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -190,52 +190,53 @@ def get_world_rank_from_group_rank(group, group_rank_id):
|
|
|
|
|
int, the rank ID in world communication group.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group_rank_id is not a int or group is not a string.
|
|
|
|
|
TypeError: If `group_rank_id` is not an integer or the group is not a string.
|
|
|
|
|
ValueError: If group is 'hccl_world_group' or backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_group_rank_from_world_rank(world_rank_id, group):
    """
    Get the rank ID in the specified user communication group corresponding to
    the rank ID in the world communication group.

    Note:
        NCCL is not supported.
        The parameter group should not be "hccl_world_group".

    Args:
        world_rank_id (int): A rank ID in the world communication group.
        group (str): The user communication group.

    Returns:
        int, the rank ID in the user communication group.

    Raises:
        TypeError: If `world_rank_id` is not an integer or `group` is not a string.
        ValueError: If `group` is "hccl_world_group" or backend is invalid.
        RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
    """
    # The lookup itself is done by the backend-aware helper; the backend is
    # taken from GlobalComm.
    return _get_group_rank_from_world_rank_helper(
        world_rank_id=world_rank_id,
        group=group,
        backend=GlobalComm.BACKEND,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_group(group, rank_ids):
|
|
|
|
|
"""
|
|
|
|
|
Creates user collective communication group.
|
|
|
|
|
Create a user collective communication group.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
Nccl is not supported.
|
|
|
|
|
NCCL is not supported.
|
|
|
|
|
The size of rank_ids should be larger than 1.
|
|
|
|
|
Rank_ids should not have duplicate data.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
group (str): ProcessGroup, the process group to create.
|
|
|
|
|
rank_ids (list): List of device ID.
|
|
|
|
|
rank_ids (list): A list of device IDs.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string or rank_ids is not a list.
|
|
|
|
|
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
|
|
|
|
|
TypeError: If group is not a string or `rank_ids` is not a list.
|
|
|
|
|
ValueError: If `rank_ids` size is not larger than 1, or `rank_ids` has duplicate data, or backend is invalid.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
Examples:
|
|
|
|
|
>>> group = "0-1"
|
|
|
|
@ -247,7 +248,7 @@ def create_group(group, rank_ids):
|
|
|
|
|
|
|
|
|
|
def destroy_group(group):
|
|
|
|
|
"""
|
|
|
|
|
Destroys user collective communication group.
|
|
|
|
|
Destroy the user collective communication group.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
NCCL is not supported.
|
|
|
|
@ -259,6 +260,6 @@ def destroy_group(group):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If group is not a string.
|
|
|
|
|
ValueError: If group is "hccl_world_group" or backend is invalid.
|
|
|
|
|
RuntimeError: If hccl/nccl is not available or nccl not supports.
|
|
|
|
|
RuntimeError: If HCCL/NCCL is not available or NCCL is not supported.
|
|
|
|
|
"""
|
|
|
|
|
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
|
|
|
|
|