@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/imperative/nccl_context.h"
+#include "paddle/fluid/platform/collective_helper.h"
 
 namespace paddle {
 namespace imperative {
@@ -115,7 +116,6 @@ void NCCLParallelContext::BcastNCCLId(ncclUniqueId *nccl_id, int root) {
 
 void NCCLParallelContext::Init() {
   ncclUniqueId nccl_id;
-  ncclComm_t comm;
   if (strategy_.local_rank_ == 0) {
     // generate the unique ncclid on the root worker
     platform::dynload::ncclGetUniqueId(&nccl_id);
@@ -128,12 +128,13 @@ void NCCLParallelContext::Init() {
           << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id;
 
   PADDLE_ENFORCE(cudaSetDevice(gpu_id));
-  PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
-      &comm, strategy_.nranks_, nccl_id, strategy_.local_rank_));
+  platform::NCCLComm *nccl_comm =
+      platform::NCCLCommContext::Instance().CreateNCCLComm(
+          &nccl_id, strategy_.nranks_, strategy_.local_rank_, gpu_id, 0);
 
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(pool.Get(place_));
-  dev_ctx->set_nccl_comm(comm);
+  dev_ctx->set_nccl_comm(nccl_comm->comm());
 }
 #endif
 
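
The hunks above replace the direct `ncclCommInitRank` call, which left the communicator in a function-local `ncclComm_t`, with `platform::NCCLCommContext::Instance().CreateNCCLComm(...)`, so the communicator is created and owned by a process-wide registry and the device context only borrows it via `nccl_comm->comm()`. Below is a minimal sketch of what such a helper is assumed to do, inferred only from the call site in this diff; the names `NCCLCommRegistry` and `NCCLCommWrapper` are hypothetical stand-ins for Paddle's `platform::NCCLCommContext` and `platform::NCCLComm`, whose real implementations in collective_helper.h may differ.

// Minimal sketch, assuming only the CreateNCCLComm call site shown in the
// diff above; this is NOT the actual collective_helper.h implementation.
#include <cuda_runtime.h>
#include <nccl.h>

#include <map>
#include <memory>

class NCCLCommWrapper {  // hypothetical stand-in for platform::NCCLComm
 public:
  explicit NCCLCommWrapper(ncclComm_t comm) : comm_(comm) {}
  ncclComm_t comm() const { return comm_; }

 private:
  ncclComm_t comm_;
};

class NCCLCommRegistry {  // hypothetical stand-in for NCCLCommContext
 public:
  static NCCLCommRegistry &Instance() {
    static NCCLCommRegistry registry;
    return registry;
  }

  // Initializes the communicator on the given device and keeps ownership
  // in the registry (keyed by ring id) instead of returning a raw handle
  // the caller must manage.
  NCCLCommWrapper *CreateNCCLComm(ncclUniqueId *nccl_id, int nranks, int rank,
                                  int dev_id, int ring_id) {
    ncclComm_t comm;
    cudaSetDevice(dev_id);
    ncclCommInitRank(&comm, nranks, *nccl_id, rank);
    comms_[ring_id].reset(new NCCLCommWrapper(comm));
    return comms_[ring_id].get();
  }

 private:
  std::map<int, std::unique_ptr<NCCLCommWrapper>> comms_;
};

Centralizing ownership this way lets other collective ops look up and reuse the same communicator and destroy it in one place at shutdown, which matches the second change in the diff: the device context now reads the handle through `nccl_comm->comm()` rather than owning a local `comm`.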