|
|
|
@ -246,7 +246,7 @@ struct FetchOpHandle : public OpHandle {
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
|
public:
|
|
|
|
|
explicit ParallelExecutorPrivate(size_t num_threads)
|
|
|
|
|
: pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
|
: pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
|
|
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
|
|
|
|
|
@ -365,7 +365,7 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> g(g_nccl_mtx_);
|
|
|
|
|
|
|
|
|
|
platform::dynload::ncclGroupStart();
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGroupStart());
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
|
|
|
|
|
auto &p = member_->places_[i];
|
|
|
|
@ -383,11 +383,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &nccl_ctx = member_->communication_streams_.at(dev_id);
|
|
|
|
|
platform::dynload::ncclAllReduce(
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
|
|
|
|
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream());
|
|
|
|
|
nccl_ctx.comm, nccl_ctx.stream()));
|
|
|
|
|
}
|
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGroupEnd());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|