|
|
|
@ -24,7 +24,7 @@ std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
|
|
|
|
|
// Out-of-class definition of the static member NCCLWrapper::is_initialized_;
// the flag starts out as false.
bool NCCLWrapper::is_initialized_{false};
|
|
|
|
|
|
|
|
|
|
void NCCLWrapper::InitNCCL() {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
|
|
|
|
|
&(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
|
|
|
|
|
nccl_info_.my_global_rank_));
|
|
|
|
@ -33,14 +33,14 @@ void NCCLWrapper::InitNCCL() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Stores the NCCL id carried by `nccl_info` into this wrapper's own
// `nccl_info_` member. Only the `nccl_id_` field is copied; no other field of
// `nccl_info_` is touched here. The assignment sits behind preprocessor
// guards, so when they are not satisfied the function just returns.
//
// NOTE(review): two `#if` lines are visible below but only one `#endif`.
// This looks like a diff rendering in which both the old guard
// (PADDLE_WITH_CUDA && !_WIN32) and the new guard (PADDLE_WITH_NCCL) are
// shown — confirm which single guard is intended before relying on this text.
void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
// Copy just the id; rank/communicator fields are left as they are.
nccl_info_.nccl_id_ = nccl_info.nccl_id_;
|
|
|
|
|
#endif
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NCCLInfo NCCLWrapper::GetNCCLId() {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
|
|
|
|
|
#endif
|
|
|
|
|
return nccl_info_;
|
|
|
|
@ -48,7 +48,7 @@ NCCLInfo NCCLWrapper::GetNCCLId() {
|
|
|
|
|
|
|
|
|
|
void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
|
|
|
|
|
const int ranks) {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
nccl_info_.local_rank_ = local_rank;
|
|
|
|
|
nccl_info_.my_global_rank_ = global_rank;
|
|
|
|
|
nccl_info_.global_ranks_ = ranks;
|
|
|
|
@ -60,7 +60,7 @@ void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
|
|
|
|
|
|
|
|
|
|
void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
|
|
|
|
|
const std::vector<std::string>& var_names) {
|
|
|
|
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL)
|
|
|
|
|
for (auto& name : var_names) {
|
|
|
|
|
auto var = scope.FindVar(name);
|
|
|
|
|
LoDTensor* tensor = var->GetMutable<LoDTensor>();
|
|
|
|
|