|
|
|
@ -300,8 +300,6 @@ class ParallelExecutorPrivate {
|
|
|
|
|
std::unique_ptr<platform::EnforceNotMet> exception_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::mutex g_nccl_mtx_;
|
|
|
|
|
|
|
|
|
|
struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
ParallelExecutorPrivate *member_;
|
|
|
|
|
|
|
|
|
@ -327,9 +325,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
int dtype = -1;
|
|
|
|
|
size_t numel = 0;
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> g(g_nccl_mtx_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
|
|
|
|
|
platform::NCCLGroupGuard guard;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
|
|
|
|
|
auto &p = member_->places_[i];
|
|
|
|
@ -355,7 +351,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream()));
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|