|
|
|
@ -19,6 +19,8 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
// Name of the NCCLInitOp input that holds the parallel scopes. The op creates
// one GPU peer per scope stored in this variable.
static constexpr char kParallelScopes[] = "parallel_scopes";
|
|
|
|
|
|
|
|
|
|
// NCCLInitOp
|
|
|
|
|
class NCCLInitOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
@ -29,24 +31,37 @@ class NCCLInitOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
// Initializes the output Communicator with one GPU peer per parallel scope.
//
// Both the kParallelScopes input variable and the "Communicator" output
// variable must already exist in `scope`; otherwise PADDLE_ENFORCE_NOT_NULL
// fails. Fails as well when the list of parallel scopes is empty.
// `place` is not used by this operator.
void Run(const framework::Scope &scope,
         const platform::Place &place) const override {
  PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
                          "Can not find variable '%s' in the scope.",
                          kParallelScopes);
  const auto &name = Output("Communicator");
  PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                          "Can not find variable '%s' in the scope.", name);

  // A parallel do may not use all the gpus. For example, the batch size is 7
  // in the last batch while we have 8 gpus. In this case, parallel_do will
  // create 7 parallel scopes, so ncclInitOp should create 7 gpu peers too.
  const auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
                                    ->Get<std::vector<framework::Scope *>>();

  // Peer i is bound to device id i: gpus = {0, 1, ..., num_peers - 1}.
  const int num_peers = static_cast<int>(parallel_scopes.size());
  std::vector<int> gpus(parallel_scopes.size());
  for (int i = 0; i < num_peers; ++i) {
    gpus[i] = i;
  }
  PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");

  // The output variable was verified to be non-null above, so no second
  // null check is needed before dereferencing it here.
  platform::Communicator *comm =
      scope.FindVar(name)->GetMutable<platform::Communicator>();
  comm->InitAll(gpus);
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -70,6 +85,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput(kParallelScopes, "The working place of parallel do.");
|
|
|
|
|
AddOutput("Communicator",
|
|
|
|
|
"Create Communicator for communicating between gpus");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|