|
|
@ -46,14 +46,17 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
virtual ~NCCLOpHandleBase() {
|
|
|
|
virtual ~NCCLOpHandleBase() {
|
|
|
|
for (auto& ev : inter_events_) {
|
|
|
|
for (auto& ev : inter_events_) {
|
|
|
|
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto& ev : exter_events_) {
|
|
|
|
for (auto& ev : exter_events_) {
|
|
|
|
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventDestroy(ev.second));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
|
|
|
|
void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
|
|
|
|
PADDLE_ENFORCE(run_order >= 0, "run_order must >= 0");
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
|
|
|
run_order, 0,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The argument run_order must be >= 0, but got %d.", run_order));
|
|
|
|
run_order_ = run_order;
|
|
|
|
run_order_ = run_order;
|
|
|
|
use_hierarchical_allreduce_ = use_hierarchical_allreduce;
|
|
|
|
use_hierarchical_allreduce_ = use_hierarchical_allreduce;
|
|
|
|
|
|
|
|
|
|
|
@ -74,8 +77,11 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(places_.size() == 1,
|
|
|
|
PADDLE_ENFORCE_EQ(places_.size(), 1,
|
|
|
|
"HierarchicalAllReduce run one proc with one card mode.");
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"HierarchicalAllReduce can only run "
|
|
|
|
|
|
|
|
"one proccess with one card mode, but got %d cards.",
|
|
|
|
|
|
|
|
places_.size()));
|
|
|
|
|
|
|
|
|
|
|
|
for (auto& p : places_) {
|
|
|
|
for (auto& p : places_) {
|
|
|
|
auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
|
|
|
|
auto ctxs = nccl_ctxs_->GetHierarchicalInterCtx(run_order);
|
|
|
@ -88,11 +94,11 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(cudaSetDevice(dev_id));
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaSetDevice(dev_id));
|
|
|
|
PADDLE_ENFORCE(cudaEventCreateWithFlags(&inter_events_[dev_id],
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
|
|
|
|
cudaEventDisableTiming));
|
|
|
|
&inter_events_[dev_id], cudaEventDisableTiming));
|
|
|
|
PADDLE_ENFORCE(cudaEventCreateWithFlags(&exter_events_[dev_id],
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventCreateWithFlags(
|
|
|
|
cudaEventDisableTiming));
|
|
|
|
&exter_events_[dev_id], cudaEventDisableTiming));
|
|
|
|
VLOG(10) << "Create events on dev_id:" << dev_id
|
|
|
|
VLOG(10) << "Create events on dev_id:" << dev_id
|
|
|
|
<< ", inter_event:" << &inter_events_[dev_id]
|
|
|
|
<< ", inter_event:" << &inter_events_[dev_id]
|
|
|
|
<< ", exter_event:" << &exter_events_[dev_id];
|
|
|
|
<< ", exter_event:" << &exter_events_[dev_id];
|
|
|
@ -102,7 +108,10 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
void FlatNCCLAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void FlatNCCLAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
|
|
|
run_order_, 0,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The argument run_order_ must be >= 0, but got %d.", run_order_));
|
|
|
|
auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
|
|
|
|
auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
|
|
|
|
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
|
|
|
|
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
|
|
|
|
auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
|
|
|
|
auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
|
|
|
@ -113,14 +122,17 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
<< ", dev_id:" << dev_id << ", dtype:" << datatype
|
|
|
|
<< ", dev_id:" << dev_id << ", dtype:" << datatype
|
|
|
|
<< ", place:" << place;
|
|
|
|
<< ", place:" << place;
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
|
|
|
|
sendbuff, recvbuff, count, datatype, op, comm, stream));
|
|
|
|
sendbuff, recvbuff, count, datatype, op, comm, stream));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NCCLAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void NCCLAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
|
|
|
run_order_, 0,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The argument run_order_ must be >= 0, but got %d.", run_order_));
|
|
|
|
if (!use_hierarchical_allreduce_) {
|
|
|
|
if (!use_hierarchical_allreduce_) {
|
|
|
|
FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
|
|
|
|
FlatNCCLAllReduce(place, sendbuff, recvbuff, count, datatype, op);
|
|
|
|
return;
|
|
|
|
return;
|
|
|
@ -132,7 +144,10 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
void HierarchicalAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void HierarchicalAllReduce(platform::Place place, const void* sendbuff,
|
|
|
|
void* recvbuff, size_t count,
|
|
|
|
void* recvbuff, size_t count,
|
|
|
|
ncclDataType_t datatype, ncclRedOp_t op) {
|
|
|
|
ncclDataType_t datatype, ncclRedOp_t op) {
|
|
|
|
PADDLE_ENFORCE(run_order_ >= 0, "run_order must > 0");
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
|
|
|
run_order_, 0,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"The argument run_order_ must be >= 0, but got %d.", run_order_));
|
|
|
|
InterReduce(place, sendbuff, recvbuff, count, datatype, op);
|
|
|
|
InterReduce(place, sendbuff, recvbuff, count, datatype, op);
|
|
|
|
// When a trainer is not in exter allreduce ring
|
|
|
|
// When a trainer is not in exter allreduce ring
|
|
|
|
// they need not to call this.
|
|
|
|
// they need not to call this.
|
|
|
@ -157,14 +172,13 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
<< ", dtype:" << datatype << ", place:" << place
|
|
|
|
<< ", dtype:" << datatype << ", place:" << place
|
|
|
|
<< ", stream:" << stream;
|
|
|
|
<< ", stream:" << stream;
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclReduce(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce(
|
|
|
|
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));
|
|
|
|
sendbuff, recvbuff, count, datatype, ncclSum, 0, comm, stream));
|
|
|
|
|
|
|
|
|
|
|
|
cudaEventRecord(inter_events_.at(dev_id), stream);
|
|
|
|
cudaEventRecord(inter_events_.at(dev_id), stream);
|
|
|
|
|
|
|
|
|
|
|
|
if (FLAGS_sync_nccl_allreduce) {
|
|
|
|
if (FLAGS_sync_nccl_allreduce) {
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream),
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
|
|
|
|
"sync HierarchicalAllReduce inter stream error");
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -172,7 +186,9 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
void* recvbuff, size_t count, ncclDataType_t datatype,
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
ncclRedOp_t op) {
|
|
|
|
auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
|
|
|
|
auto nccl_ctxs = nccl_ctxs_->GetHierarchicalExterCtx(run_order_);
|
|
|
|
PADDLE_ENFORCE(nccl_ctxs_, "can't get exter %d nccl_ctxs", run_order_);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
|
|
|
nccl_ctxs_, platform::errors::NotFound(
|
|
|
|
|
|
|
|
"Can't get exter %d nccl contexts.", run_order_));
|
|
|
|
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
|
|
|
|
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
|
|
|
|
auto& nccl_ctx = nccl_ctxs->at(dev_id);
|
|
|
|
auto& nccl_ctx = nccl_ctxs->at(dev_id);
|
|
|
|
auto stream = nccl_ctx.stream();
|
|
|
|
auto stream = nccl_ctx.stream();
|
|
|
@ -185,14 +201,13 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
|
|
|
|
|
|
|
|
cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);
|
|
|
|
cudaStreamWaitEvent(stream, inter_events_.at(dev_id), 0);
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
|
|
|
|
sendbuff, recvbuff, count, datatype, op, comm, stream));
|
|
|
|
sendbuff, recvbuff, count, datatype, op, comm, stream));
|
|
|
|
|
|
|
|
|
|
|
|
cudaEventRecord(exter_events_.at(dev_id), stream);
|
|
|
|
cudaEventRecord(exter_events_.at(dev_id), stream);
|
|
|
|
|
|
|
|
|
|
|
|
if (FLAGS_sync_nccl_allreduce) {
|
|
|
|
if (FLAGS_sync_nccl_allreduce) {
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream),
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
|
|
|
|
"sync HierarchicalAllReduce exter stream error");
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -210,8 +225,8 @@ class NCCLOpHandleBase : public OpHandleBase {
|
|
|
|
<< ", stream:" << stream;
|
|
|
|
<< ", stream:" << stream;
|
|
|
|
|
|
|
|
|
|
|
|
cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
|
|
|
|
cudaStreamWaitEvent(stream, exter_events_.at(dev_id), 0);
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclBcast(sendbuff, count, datatype, 0,
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
|
|
|
|
comm, stream));
|
|
|
|
sendbuff, count, datatype, 0, comm, stream));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|