|
|
|
@@ -40,28 +40,21 @@ class NCCLInitOp : public framework::OperatorBase {
     // A parallel do may not use all the gpus. For example, the batch size is 7
     // in the last batch while we have 8 gpu. In this case, parallel_do will
     // create 7 parallel scopes, so should ncclInitOp create 7 gpu peers
-    LOG(INFO) << "---------------";
     auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
                                 ->Get<std::vector<framework::Scope *>>();
-    LOG(INFO) << "---------------";
     std::vector<int> gpus(parallel_scopes.size());
     for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
       gpus[i] = i;
     }
-    LOG(INFO) << "---------------";
     PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
-    LOG(INFO) << "---------------";
 
     if (scope.FindVar(name) == nullptr) {
       PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
     }
-    LOG(INFO) << "---------------";
 
     platform::Communicator *comm =
         scope.FindVar(name)->GetMutable<platform::Communicator>();
-    LOG(INFO) << "---------------";
     comm->InitAll(gpus);
-    LOG(INFO) << "---------------";
   }
 };
 
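For context on the line the hunk keeps: `comm->InitAll(gpus)` is where the `platform::Communicator` sets up one NCCL communicator per listed device. Below is a minimal standalone sketch of the equivalent raw NCCL call for the 7-of-8-GPU case described in the comment. It is an illustration, not Paddle's actual `Communicator::InitAll` implementation; it assumes NCCL is installed, and the `NCCL_CHECK` macro is hypothetical (defined here, not part of Paddle or NCCL).

#include <cstdio>
#include <cstdlib>
#include <vector>

#include <nccl.h>

// Hypothetical helper macro: abort on any NCCL error.
#define NCCL_CHECK(cmd)                                              \
  do {                                                               \
    ncclResult_t r = (cmd);                                          \
    if (r != ncclSuccess) {                                          \
      std::fprintf(stderr, "NCCL error: %s\n", ncclGetErrorString(r)); \
      std::exit(1);                                                  \
    }                                                                \
  } while (0)

int main() {
  // One entry per parallel scope: with a batch of 7 on an 8-GPU box,
  // devices 0..6 join the clique and device 7 sits out, mirroring
  // the gpus vector NCCLInitOp builds from parallel_scopes.size().
  std::vector<int> gpus(7);
  for (int i = 0; i < static_cast<int>(gpus.size()); ++i) gpus[i] = i;

  // One communicator handle per participating device, all created
  // in a single call from the current process.
  std::vector<ncclComm_t> comms(gpus.size());
  NCCL_CHECK(ncclCommInitAll(comms.data(), static_cast<int>(gpus.size()),
                             gpus.data()));

  // ... launch collectives (e.g. ncclAllReduce) on these comms ...

  for (ncclComm_t c : comms) ncclCommDestroy(c);
  return 0;
}

The reason the operator sizes `gpus` from the scope count rather than the physical device count is that the communicator clique must match the number of parallel scopes exactly: an NCCL collective blocks until every rank in the clique participates, so listing an eighth device that no scope ever drives would hang the job.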