You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
7.6 KiB
255 lines
7.6 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""HCCL management API"""
|
|
import ctypes
|
|
|
|
MAX_GROUP_NAME_LEN = 127
|
|
MAX_RANK_NUM = 4096
|
|
HCCL_LIB = 'libhccl.so'
|
|
HCCL_LIB_CTYPES = ""
|
|
|
|
|
|
def check_group(group):
|
|
"""
|
|
A function that check if a collection communication group is leagal.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
if isinstance(group, (str)):
|
|
group_len = len(group)
|
|
if group_len > MAX_GROUP_NAME_LEN or group_len == 0:
|
|
raise ValueError('Group name is invalid.')
|
|
else:
|
|
raise TypeError('Group must be a python str.')
|
|
|
|
|
|
def check_rank_num(rank_num):
|
|
"""
|
|
A function that check if a collection communication rank number is leagal.If not raise error.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
if isinstance(rank_num, (int)):
|
|
if rank_num > MAX_RANK_NUM or rank_num <= 0:
|
|
raise ValueError('Rank number is out of range.')
|
|
else:
|
|
raise TypeError('Rank number must be a python int.')
|
|
|
|
|
|
def check_rank_id(rank_id):
|
|
"""
|
|
A function that check if a collection communication rank id is leagal.If not raise error.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
if isinstance(rank_id, (int)):
|
|
if rank_id >= MAX_RANK_NUM or rank_id < 0:
|
|
raise ValueError('Rank id is out of range.')
|
|
else:
|
|
raise TypeError('Rank id must be a python int.')
|
|
|
|
|
|
def load_lib():
|
|
try:
|
|
hccl_lib = ctypes.CDLL(HCCL_LIB)
|
|
except Exception:
|
|
raise RuntimeError('Get hccl lib error.')
|
|
global HCCL_LIB_CTYPES
|
|
HCCL_LIB_CTYPES = hccl_lib
|
|
|
|
|
|
def c_str(string):
|
|
"""Convert a python string to C string."""
|
|
if not isinstance(string, str):
|
|
string = string.decode('ascii')
|
|
return ctypes.c_char_p(string.encode('utf-8'))
|
|
|
|
|
|
def c_array(ctype, values):
|
|
"""Create ctypes array from a python array."""
|
|
return (ctype * len(values))(*values)
|
|
|
|
|
|
def create_group(group, rank_num, rank_ids):
|
|
"""
|
|
Create group.
|
|
|
|
A function that creates a collection communication group which includes 'rank_num'
|
|
device and 'rank_ids' is the list of these ranks of devices.
|
|
|
|
Note:
|
|
The world group can not be created.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
check_group(group)
|
|
check_rank_num(rank_num)
|
|
if isinstance(rank_ids, (list)):
|
|
if rank_num != len(rank_ids):
|
|
raise ValueError('Rank number is not equal to the length of rank_ids.')
|
|
for rank_id in rank_ids:
|
|
if not isinstance(rank_id, (int)) or rank_id < 0:
|
|
raise ValueError('Rank id must be unsigned integer!')
|
|
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
|
|
c_rank_num = ctypes.c_uint(rank_num)
|
|
c_group = c_str(group)
|
|
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
|
|
if ret != 0:
|
|
raise RuntimeError('Create group error.')
|
|
else:
|
|
raise TypeError('Rank ids must be a python list.')
|
|
|
|
|
|
def destroy_group(group):
|
|
"""
|
|
A function that destroy the group which created by user.
|
|
|
|
Note:
|
|
The world group can not be destroy.
|
|
|
|
Returns:
|
|
None
|
|
"""
|
|
check_group(group)
|
|
c_group = c_str(group)
|
|
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
|
|
if ret != 0:
|
|
raise RuntimeError('Destroy group error.')
|
|
|
|
|
|
def get_rank_size(group="hccl_world_group"):
|
|
"""
|
|
A function that returns the number of ranks within the given collection communication group.
|
|
|
|
Note:
|
|
The default group is hccl_world_group.
|
|
|
|
Returns:
|
|
An integer scalar with the num of ranks.
|
|
"""
|
|
check_group(group)
|
|
c_group = c_str(group)
|
|
c_rank_size = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
|
|
if ret != 0:
|
|
raise RuntimeError('Get rank size error.')
|
|
|
|
return c_rank_size.value
|
|
|
|
|
|
def get_rank_id(group="hccl_world_group"):
|
|
"""
|
|
A function that returns the rank id of the calling process, within the given collection communication group.
|
|
|
|
Returns:
|
|
An integer scalar with the rank id of the calling process.
|
|
"""
|
|
check_group(group)
|
|
c_group = c_str(group)
|
|
c_rank_id = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
|
|
if ret != 0:
|
|
raise RuntimeError('Get rank id error.')
|
|
|
|
return c_rank_id.value
|
|
|
|
|
|
def get_local_rank_size(group="hccl_world_group"):
|
|
"""
|
|
A function that returns the number of local ranks within the given collection communication group.
|
|
|
|
Note:
|
|
The default group is hccl_world_group.
|
|
|
|
Returns:
|
|
An integer scalar with the num of local ranks.
|
|
"""
|
|
check_group(group)
|
|
c_group = c_str(group)
|
|
c_local_rank_size = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
|
|
if ret != 0:
|
|
raise RuntimeError('Get local rank size error.')
|
|
|
|
return c_local_rank_size.value
|
|
|
|
|
|
def get_local_rank_id(group="hccl_world_group"):
|
|
"""
|
|
Get local rank id.
|
|
|
|
A function that returns the local rank id of the calling process, within the given collection communication group.
|
|
|
|
Returns:
|
|
An integer scalar with the local rank id of the calling process.
|
|
"""
|
|
check_group(group)
|
|
c_group = c_str(group)
|
|
c_local_rank_id = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
|
|
if ret != 0:
|
|
raise RuntimeError('Get local rank id error.')
|
|
|
|
return c_local_rank_id.value
|
|
|
|
|
|
def get_world_rank_from_group_rank(group, group_rank_id):
|
|
"""
|
|
Get world rank from group rank.
|
|
|
|
A function that returns the rank id in the world group corresponding to the
|
|
rank which id is 'group_rank_id' in the user group.
|
|
|
|
Returns:
|
|
An integer scalar with the rank id in the world group.
|
|
"""
|
|
check_group(group)
|
|
check_rank_id(group_rank_id)
|
|
c_group = c_str(group)
|
|
c_group_rank_id = ctypes.c_uint(group_rank_id)
|
|
c_world_rank_id = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
|
|
if ret != 0:
|
|
raise RuntimeError('Get world rank from group rank error.')
|
|
|
|
return c_world_rank_id.value
|
|
|
|
|
|
def get_group_rank_from_world_rank(world_rank_id, group):
|
|
"""
|
|
Get group rank from world rank.
|
|
|
|
A function that returns the rank id in the user group corresponding to the
|
|
rank which id is 'world_rank_id' in the world group.
|
|
|
|
Returns:
|
|
An integer scalar with the rank id in the user group.
|
|
"""
|
|
check_group(group)
|
|
check_rank_id(world_rank_id)
|
|
c_group = c_str(group)
|
|
c_world_rank_id = ctypes.c_uint(world_rank_id)
|
|
c_group_rank_id = ctypes.c_uint()
|
|
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
|
|
if ret != 0:
|
|
raise RuntimeError('Get group rank from world rank error.')
|
|
|
|
return c_group_rank_id.value
|