|
|
|
@ -340,6 +340,8 @@ ncclDataType_t ToNCCLDataType(std::type_index type) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// File-local mutex (internal linkage via `static`). It is acquired with a
// std::lock_guard ahead of the platform::dynload::ncclGroupStart() call
// further down in this file.
// NOTE(review): presumably this serializes NCCL group submissions issued from
// concurrently running op handles; the full extent of the critical section is
// not visible from here -- confirm before narrowing or removing the lock.
static std::mutex g_nccl_mtx_;
|
|
|
|
|
|
|
|
|
|
struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
ParallelExecutorPrivate *member_;
|
|
|
|
|
|
|
|
|
@ -361,6 +363,8 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
int dtype = -1;
|
|
|
|
|
size_t numel = 0;
|
|
|
|
|
|
|
|
|
|
std::lock_guard<std::mutex> g(g_nccl_mtx_);
|
|
|
|
|
|
|
|
|
|
platform::dynload::ncclGroupStart();
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < member_->local_scopes_.size(); ++i) {
|
|
|
|
|