Enable get rank id and size by group

Branch: pull/3214/head
Author: ZPaC (5 years ago)
Parent: ade60ad3d3
Commit: 0bc74f28c5

@@ -279,6 +279,9 @@ if (ENABLE_GPU)
         ${CUDNN_PATH}/lib64/libcudnn.so
         ${CUDA_PATH}/lib64/libcudart.so
         ${CUDA_PATH}/lib64/stubs/libcuda.so)
+    if (ENABLE_MPI)
+        set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
+    endif()
 endif ()
 if (ENABLE_CPU)

@@ -99,5 +99,11 @@ MS_REG_GPU_KERNEL_TWO(
 MS_REG_GPU_KERNEL_TWO(
   Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
 }  // namespace kernel
 }  // namespace mindspore

@@ -98,7 +98,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
   static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
     {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
     {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
-    {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
+    {"FloorDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
+    {"TensorAdd", BROADCAST_TYPE_ADD},
   };
   auto iter = kBroadcastTypeMap.find(kernel_name);

@@ -24,17 +24,28 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllReduce,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllGather,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceScatter,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore

@@ -70,9 +70,7 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
                                  mindspore::parallel::Group *const group) {
   // it is simple to use size to determine whether it is a world group
   uint32_t world_size = 0;
-  if (world_group_ != NCCL_WORLD_GROUP) {
-    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
-  }
+  (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
   if (devices.size() == world_size) {
     auto it = groups_.find(world_group_);

@@ -55,6 +55,7 @@ if (ENABLE_GPU)
                  PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS})
     target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl)
+    target_link_libraries(_ms_mpi PRIVATE gpu_collective)
 endif ()
 # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST})

@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
+#include <nccl.h>
 #include <sstream>
 #include "pybind11/pybind11.h"
@@ -25,6 +26,12 @@ namespace device {
 namespace gpu {
 constexpr int MAX_HOSTNAME_LEN = 1024;
 constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+struct NcclGroupInfo {
+  int size;
+  int rank;
+  ncclUniqueId unique_id;
+  ncclComm_t comm;
+};
 #define CHECK_RET(expression, result, message) \
   { \
     auto ret = (expression); \

@@ -14,58 +14,37 @@
  * limitations under the License.
  */
-#include <mpi.h>
-#include <nccl.h>
-#include <unistd.h>
-#include <memory>
 #include <string>
-#include <iostream>
 #include <vector>
-#include "runtime/device/gpu/distribution/mpi_wrapper.h"
-#include "runtime/device/gpu/distribution/nccl_wrapper.h"
-#ifndef EXPORT_WRAPPER
-#define EXPORT_WRAPPER __attribute__((visibility("default")))
-#endif
-using MPIWrapper = mindspore::device::gpu::MPIWrapper;
-using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
-extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); }
-extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
-extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
-extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+void InitMPI() { MPIWrapper::instance(); }
+int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
+void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
+bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
   return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
 }
-extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
-  return MPIWrapper::instance().GetRankIDByGroup(group_name);
-}
+int GetRankIDByGroup(const std::string &group_name) { return MPIWrapper::instance().GetRankIDByGroup(group_name); }
-extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
-  return MPIWrapper::instance().GetGroupSize(group_name);
-}
+int GetGroupSize(const std::string &group_name) { return MPIWrapper::instance().GetGroupSize(group_name); }
-extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
-  return MPIWrapper::instance().DestroyGroup(group_name);
-}
+bool DestroyGroup(const std::string &group_name) { return MPIWrapper::instance().DestroyGroup(group_name); }
-extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                 cudaStream_t stream) {
-  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
-extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, cudaStream_t stream) {
-  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream);
+ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream, group);
 }
-extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                     cudaStream_t stream) {
-  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                           ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }

@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <mpi.h>
+#include <nccl.h>
+#include <unistd.h>
+#include <string>
+#include <vector>
+#include "runtime/device/gpu/distribution/mpi_wrapper.h"
+#include "runtime/device/gpu/distribution/nccl_wrapper.h"
+#ifndef EXPORT_WRAPPER
+#define EXPORT_WRAPPER __attribute__((visibility("default")))
+#endif
+using MPIWrapper = mindspore::device::gpu::MPIWrapper;
+using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
+extern "C" EXPORT_WRAPPER void InitMPI();
+extern "C" EXPORT_WRAPPER int local_rank_id();
+extern "C" EXPORT_WRAPPER void InitNCCLComm();
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks);
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name);
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
+                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
+                                                     cudaStream_t stream, const std::string &group);
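A side note on the header above: declaring the wrappers extern "C" with default visibility keeps their exported symbol names unmangled, which makes it possible for a consumer to resolve them by name at runtime instead of (or in addition to) linking against gpu_collective directly. Whether the MindSpore runtime actually loads the library this way is not shown in this diff; the following is only a minimal illustrative sketch of such a lookup, assuming the library is built as libgpu_collective.so and is on the dynamic loader's path.

// Illustrative sketch only; the library name and loading strategy are assumptions.
#include <dlfcn.h>
#include <iostream>
#include <string>

int main() {
  // Load the collective wrapper library by name.
  void *handle = dlopen("libgpu_collective.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // C linkage means the strings below are exactly the exported symbol names.
  using InitMPIFunc = void (*)();
  using GroupQueryFunc = int (*)(const std::string &);
  auto init_mpi = reinterpret_cast<InitMPIFunc>(dlsym(handle, "InitMPI"));
  auto rank_by_group = reinterpret_cast<GroupQueryFunc>(dlsym(handle, "GetRankIDByGroup"));
  auto size_by_group = reinterpret_cast<GroupQueryFunc>(dlsym(handle, "GetGroupSize"));
  if (init_mpi == nullptr || rank_by_group == nullptr || size_by_group == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << std::endl;
    return 1;
  }
  init_mpi();
  std::cout << "rank " << rank_by_group("nccl_world_group") << " of "
            << size_by_group("nccl_world_group") << std::endl;
  dlclose(handle);
  return 0;
}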

@@ -58,7 +58,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
   if (rank_id_ == ranks[0]) {
     group_unique_id = NCCLWrapper::instance().nccl_unique_id();
   }
-  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm);
+  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm);
   int group_rank[1];
   int global_rank[1] = {rank_id_};
@@ -68,9 +68,8 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
     return false;
   }
-  ncclComm_t nccl_group_comm;
-  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
-  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group);
   return true;
 }
@@ -111,7 +110,6 @@ void MPIWrapper::Init() {
   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
-  NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
   AssignLocalRankID();
   CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD");
@@ -123,7 +121,9 @@ void MPIWrapper::Init() {
   }
   CHECK_RET(MPI_Bcast(reinterpret_cast<void *>(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD),
             MPI_SUCCESS, "Failed to broadcast nccl unique id.");
-  NCCLWrapper::instance().set_nccl_unique_id(unique_id);
+  NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group);
   return;
 }
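One detail worth spelling out in the MPI_Bcast change above: the root argument of MPI_Bcast is a rank in the communicator the call is made on. ranks[0] is a rank in MPI_COMM_WORLD, but inside mpi_group_comm the process that holds the unique id (world rank ranks[0], the first rank included in the group) is rank 0, so broadcasting with root 0 is the correct form. Below is a standalone sketch of the same idea using only standard MPI calls; the group {0, 1} is hypothetical, and MPI_Group_translate_ranks makes the world-to-group mapping explicit.

// Illustrative sketch, not part of the patch. Run with at least two processes.
#include <mpi.h>
#include <cstdio>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int world_rank = 0, world_size = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
  MPI_Comm_size(MPI_COMM_WORLD, &world_size);
  if (world_size < 2) {  // the example group below needs two processes
    MPI_Finalize();
    return 0;
  }
  // Build a sub-communicator from an explicit list of world ranks,
  // analogous to the 'ranks' argument of CreateCommGroup.
  int group_ranks[2] = {0, 1};  // hypothetical group
  MPI_Group world_group, sub_group;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
  MPI_Group_incl(world_group, 2, group_ranks, &sub_group);
  MPI_Comm sub_comm;
  MPI_Comm_create(MPI_COMM_WORLD, sub_group, &sub_comm);  // MPI_COMM_NULL for non-members
  if (sub_comm != MPI_COMM_NULL) {
    // World rank group_ranks[0] becomes rank 0 inside sub_comm because
    // MPI_Group_incl preserves the listed order.
    int root_in_sub = -1;
    MPI_Group_translate_ranks(world_group, 1, &group_ranks[0], sub_group, &root_in_sub);
    int payload = (world_rank == group_ranks[0]) ? 42 : 0;
    MPI_Bcast(&payload, 1, MPI_INT, root_in_sub /* == 0 here */, sub_comm);
    std::printf("world rank %d got %d from sub-communicator root %d\n", world_rank, payload, root_in_sub);
    MPI_Comm_free(&sub_comm);
  }
  MPI_Group_free(&sub_group);
  MPI_Group_free(&world_group);
  MPI_Finalize();
  return 0;
}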

@@ -30,60 +30,58 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const {
   return unique_id;
 }
-void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; }
-void NCCLWrapper::set_rank(int rank_id, int rank_size) {
-  rank_id_ = rank_id;
-  rank_size_ = rank_size;
-}
 void NCCLWrapper::InitNCCLComm() {
-  CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
-            "Failed to init nccl communicator.");
-  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
-}
-void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
-  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
+  for (auto group : group_info_) {
+    std::string group_name = group.first;
+    NcclGroupInfo group_info = group.second;
+    CHECK_RET(ncclCommInitRank(&(group_info.comm), group_info.size, group_info.unique_id, group_info.rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+    group_info_[group_name].comm = group_info.comm;
+  }
+  comm_init_done_ = true;
 }
 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllGather by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }
 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
                                         ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
                                         const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
-void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
-  group_to_comm_map_[group_name] = comm;
+void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) {
+  if (comm_init_done_) {
+    CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+  }
+  group_info_[group_name] = *group;
 }
 void NCCLWrapper::DestroyGroup(const std::string &group_name) {
-  auto group_iter = group_to_comm_map_.find(group_name);
-  if (group_iter == group_to_comm_map_.end()) {
+  auto group_iter = group_info_.find(group_name);
+  if (group_iter == group_info_.end()) {
     return;
   }
-  group_to_comm_map_.erase(group_iter);
-  ncclComm_t group_comm = group_iter->second;
+  ncclComm_t group_comm = group_iter->second.comm;
   CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  group_info_.erase(group_iter);
   return;
 }
 }  // namespace gpu
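The rewritten NCCLWrapper above follows a register-then-initialize pattern: AddGroupInfo only records size, rank, and unique id for a group, the actual ncclCommInitRank calls happen in one pass inside InitNCCLComm, and only groups registered after that point (comm_init_done_) are initialized eagerly. A minimal, generic sketch of the same pattern follows; the names are illustrative and are not MindSpore APIs.

// Generic sketch of deferred vs. eager group initialization (illustrative names).
#include <iostream>
#include <map>
#include <string>

struct GroupInfo {
  int size = 0;
  int rank = 0;
  bool initialized = false;  // stands in for a live communicator handle
};

class GroupRegistry {
 public:
  void AddGroup(const std::string &name, GroupInfo group) {
    if (init_done_) {
      Initialize(&group);  // global init already ran: create the comm immediately
    }
    groups_[name] = group;  // otherwise creation is deferred to InitAll()
  }
  void InitAll() {
    for (auto &entry : groups_) {
      Initialize(&entry.second);  // one pass over everything registered so far
    }
    init_done_ = true;
  }
  bool IsInitialized(const std::string &name) const { return groups_.at(name).initialized; }

 private:
  static void Initialize(GroupInfo *group) { group->initialized = true; }
  bool init_done_ = false;
  std::map<std::string, GroupInfo> groups_;
};

int main() {
  GroupRegistry registry;
  registry.AddGroup("nccl_world_group", {8, 3});  // registered before init: deferred
  registry.InitAll();
  registry.AddGroup("group_0_3", {4, 1});         // registered after init: eager
  std::cout << std::boolalpha << registry.IsInitialized("group_0_3") << std::endl;  // true
  return 0;
}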

@@ -33,29 +33,23 @@ class NCCLWrapper {
   NCCLWrapper &operator=(const NCCLWrapper &) = delete;
   static NCCLWrapper &instance();
   ncclUniqueId nccl_unique_id() const;
-  void set_nccl_unique_id(ncclUniqueId unique_id);
-  void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
-  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                              ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
-  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group);
   void DestroyGroup(const std::string &group_name);
 private:
-  NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
+  NCCLWrapper() : comm_init_done_(false) {}
   ~NCCLWrapper() = default;
 private:
-  int rank_id_;
-  int rank_size_;
-  ncclUniqueId unique_id_;
-  ncclComm_t comm_;
-  std::map<std::string, ncclComm_t> group_to_comm_map_;
+  bool comm_init_done_;
+  std::map<std::string, NcclGroupInfo> group_info_;
 };
 }  // namespace gpu
 }  // namespace device

@@ -15,45 +15,24 @@
  */
 #include "runtime/device/gpu/mpi/mpi_initializer.h"
-#include <dlfcn.h>
 #include <mpi.h>
 #include <pybind11/operators.h>
 #include <iostream>
+#include <string>
 namespace mindspore {
 namespace device {
 namespace gpu {
-MPIInitializer::MPIInitializer() {
-  int init_flag = 0;
-  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    return;
-  }
-  if (init_flag == 0) {
-    auto ret = MPI_Init(nullptr, nullptr);
-    if (ret != MPI_SUCCESS) {
-      return;
-    }
-  }
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
-  MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
-}
-MPIInitializer::~MPIInitializer() {
-  int finalized_flag = 0;
-  (void)MPI_Finalized(&finalized_flag);
-  if (finalized_flag == 0) {
-    (void)MPI_Finalize();
-  }
-}
 MPIInitializer &MPIInitializer::GetInstance() {
   static MPIInitializer instance;
   return instance;
 }
-int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }
-int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
+int MPIInitializer::get_rank_id(const std::string &group) { return GetRankIDByGroup(group); }
+int MPIInitializer::get_rank_size(const std::string &group) { return GetGroupSize(group); }
 PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
   mpi_initializer.doc() = "mindspore mpi python wrapper";

@@ -17,6 +17,9 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
+#include <string>
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
 namespace mindspore {
 namespace device {
 namespace gpu {
@@ -25,15 +28,12 @@ class MPIInitializer {
   MPIInitializer(MPIInitializer const &) = delete;
   MPIInitializer &operator=(const MPIInitializer &) = delete;
   static MPIInitializer &GetInstance();
-  static int get_rank_id();
-  static int get_rank_size();
+  static int get_rank_id(const std::string &group);
+  static int get_rank_size(const std::string &groups);
 private:
-  MPIInitializer();
-  ~MPIInitializer();
-  int rank_id_;
-  int rank_size_;
+  MPIInitializer() = default;
+  ~MPIInitializer() = default;
 };
 }  // namespace gpu
 }  // namespace device

@@ -163,10 +163,7 @@ def _get_rank_helper(group, backend):
         else:
             rank_id = hccl.get_rank_id(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            rank_id = mpi.get_rank_id()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_id by user group now.")
+        rank_id = mpi.get_rank_id(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return rank_id
@@ -225,10 +222,7 @@ def _get_size_helper(group, backend):
         else:
            size = hccl.get_rank_size(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            size = mpi.get_rank_size()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_size by user group now.")
+        size = mpi.get_rank_size(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return size

@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \
.output(0, "output") \ .output(0, "output") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
.get_op_info() .get_op_info()
