Add NCCL Group Guard

helinwang-patch-1
Yu Yang 7 years ago
parent 99fe83a020
commit 41ad632341

@ -300,8 +300,6 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::EnforceNotMet> exception_;
};
static std::mutex g_nccl_mtx_;
struct NCCLAllReduceOpHandle : public OpHandle {
ParallelExecutorPrivate *member_;
@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
int dtype = -1;
size_t numel = 0;
std::lock_guard<std::mutex> g(g_nccl_mtx_);
PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
platform::NCCLGroupGuard guard;
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
auto &p = member_->places_[i];
@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
nccl_ctx.comm, nccl_ctx.stream()));
}
PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
}
}
};

@ -14,6 +14,7 @@
#pragma once
#include <thread>
#include <typeindex>
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
@ -33,5 +34,24 @@ inline ncclDataType_t ToNCCLDataType(std::type_index type) {
}
}
class NCCLGroupGuard {
public:
inline NCCLGroupGuard() {
mutex().lock();
PADDLE_ENFORCE(dynload::ncclGroupStart());
}
inline ~NCCLGroupGuard() {
PADDLE_ENFORCE(dynload::ncclGroupEnd());
mutex().unlock();
}
private:
static std::mutex& mutex() {
static std::mutex mtx;
return mtx;
}
};
} // namespace platform
} // namespace paddle

Loading…
Cancel
Save