Avoid init_nccl for every step.

Xin Pan 7 years ago
parent 158d56743f
commit d054cfeae6

@@ -16,5 +16,44 @@ limitations under the License. */
 #include "paddle/fluid/platform/gpu_info.h"
 
 namespace paddle {
-namespace platform {}  // namespace platform
+namespace platform {
+
+namespace {
+// TODO(panyx0718): Where to destroy them.
+std::unique_ptr<std::vector<ncclComm_t>> global_comms;
+std::unique_ptr<std::unordered_map<int, int>> comm_id_map;
+bool inited = false;
+size_t last_num_gpus = -1;
+}
+
+int Communicator::GetCommId(int device_id) const {
+  return comm_id_map->at(device_id);
+}
+
+void Communicator::InitAll(const std::vector<int>& gpus) {
+  if (inited && last_num_gpus == gpus.size()) {
+    return;
+  }
+  last_num_gpus = gpus.size();
+  if (global_comms) {
+    for (size_t i = 0; i < global_comms->size(); ++i) {
+      // FIXME(dzh) : PADDLE_ENFORCE return void
+      dynload::ncclCommDestroy((*global_comms)[i]);
+    }
+  }
+  global_comms.reset(new std::vector<ncclComm_t>());
+  comm_id_map.reset(new std::unordered_map<int, int>());
+  global_comms->resize(gpus.size());
+  for (size_t i = 0; i < gpus.size(); ++i) {
+    (*comm_id_map)[gpus[i]] = i;
+  }
+  PADDLE_ENFORCE(
+      dynload::ncclCommInitAll(global_comms->data(), gpus.size(), gpus.data()));
+  inited = true;
+}
+
+const std::vector<ncclComm_t>& Communicator::comms() const {
+  return *global_comms;
+}
+
+}  // namespace platform
 }  // namespace paddle
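
For reference, a minimal NCCL-free sketch of the caching pattern this hunk introduces: the communicators live in translation-unit globals, the first InitAll builds them, and later calls with an unchanged GPU count return immediately, so calling it every step no longer re-runs NCCL initialization. The InitOnce name and the main() driver below are illustrative only, not part of the patch.

// Simplified stand-in for Communicator::InitAll (illustrative, no real NCCL calls).
#include <cstdio>
#include <vector>

namespace {
bool inited = false;
size_t last_num_gpus = 0;     // GPU count the current "comms" were built for
std::vector<int> fake_comms;  // stands in for the ncclComm_t handles
}  // namespace

void InitOnce(const std::vector<int>& gpus) {
  if (inited && last_num_gpus == gpus.size()) {
    return;  // already initialized for this GPU count: early return, as in InitAll
  }
  last_num_gpus = gpus.size();
  fake_comms.assign(gpus.begin(), gpus.end());  // "create" one comm per GPU
  inited = true;
  std::printf("initialized %zu comms\n", fake_comms.size());
}

int main() {
  const std::vector<int> gpus = {0, 1};
  for (int step = 0; step < 3; ++step) {
    InitOnce(gpus);  // pays the initialization cost only on step 0
  }
  return 0;
}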

@@ -29,39 +29,16 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
 constexpr int kInvalidGPUId = -1;
 
 struct Communicator {
-  std::vector<ncclComm_t> comms_;
-  std::unordered_map<int, int> comm_id_map_;
-  bool inited_;
-
   Communicator() {}
-  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
-  void InitAll(const std::vector<int>& gpus) {
-    comms_.resize(gpus.size());
-    inited_ = false;
-    for (size_t i = 0; i < gpus.size(); ++i) {
-      comm_id_map_[gpus[i]] = i;
-    }
-    PADDLE_ENFORCE(
-        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
-    inited_ = true;
-  }
-  ~Communicator() {
-    if (inited_) {
-      for (size_t i = 0; i < comms_.size(); ++i) {
-        // FIXME(dzh) : PADDLE_ENFORCE return void
-        dynload::ncclCommDestroy(comms_[i]);
-      }
-    }
-  }
-  DISABLE_COPY_AND_ASSIGN(Communicator);
+
+  int GetCommId(int device_id) const;
+
+  void InitAll(const std::vector<int>& gpus);
+
+  const std::vector<ncclComm_t>& comms() const;
 };
 
 }  // namespace platform
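
Assembled from the hunk above, the Communicator declaration that remains in the header is now declarations only; the state and definitions move to the .cc file in the first hunk. This is the net result reconstructed for readability, not an additional change:

struct Communicator {
  Communicator() {}

  int GetCommId(int device_id) const;

  void InitAll(const std::vector<int>& gpus);

  const std::vector<ncclComm_t>& comms() const;
};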

@@ -78,7 +78,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
           ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
          outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
-          comm->comms_[idx], stream));
+          comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
 
       VLOG(1) << "gpu : "
@@ -127,7 +127,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
     std::hash<std::string> hasher;
     for (size_t i = 0; i < ins.size(); ++i) {
       if (root == platform::kInvalidGPUId) {
-        root = hasher(ins_names[i]) % comm->comms_.size();
+        root = hasher(ins_names[i]) % comm->comms().size();
       }
       T* recvbuffer = nullptr;
       if (root == gpu_id) {
@@ -139,7 +139,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclReduce(
           ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
-          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
+          NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
           stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -176,7 +176,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
       VLOG(1) << " before ncclBcast";
       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           (void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
-          root, comm->comms_[idx], stream));
+          root, comm->comms().at(idx), stream));
       VLOG(1) << " after ncclBcast";
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
@@ -190,7 +190,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE(platform::dynload::ncclBcast(
           outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
-          NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
+          NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
       PADDLE_ENFORCE(cudaStreamSynchronize(stream));
       VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
