|
|
|
@ -52,17 +52,17 @@ class GenNCCLIdOp : public framework::OperatorBase {
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
void GenerateAndSend(framework::Scope* scope,
|
|
|
|
void GenerateAndSend(framework::Scope* scope,
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
auto var = scope->FindVar("NCCLID");
|
|
|
|
auto var = scope->FindVar(NCCL_ID_VARNAME);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
auto id = var->GetMutable<ncclUniqueId>();
|
|
|
|
auto id = var->GetMutable<ncclUniqueId>();
|
|
|
|
platform::dynload::ncclGetUniqueId(id);
|
|
|
|
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> endpoint_list =
|
|
|
|
std::vector<std::string> endpoint_list =
|
|
|
|
Attr<std::vector<std::string>>("endpoint_list");
|
|
|
|
Attr<std::vector<std::string>>("endpoint_list");
|
|
|
|
detail::RPCClient client;
|
|
|
|
detail::RPCClient client;
|
|
|
|
for (auto& ep : endpoint_list) {
|
|
|
|
for (auto& ep : endpoint_list) {
|
|
|
|
VLOG(3) << "sending nccl id to " << ep;
|
|
|
|
VLOG(3) << "sending nccl id to " << ep;
|
|
|
|
client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID");
|
|
|
|
client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
client.Wait();
|
|
|
|
client.Wait();
|
|
|
|
VLOG(3) << "sending completed...";
|
|
|
|
VLOG(3) << "sending completed...";
|
|
|
|
@ -71,6 +71,9 @@ class GenNCCLIdOp : public framework::OperatorBase {
|
|
|
|
void GetIdByServer(framework::Scope* scope,
|
|
|
|
void GetIdByServer(framework::Scope* scope,
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
const platform::DeviceContext& dev_ctx) const {
|
|
|
|
std::string endpoint = Attr<std::string>("endpoint");
|
|
|
|
std::string endpoint = Attr<std::string>("endpoint");
|
|
|
|
|
|
|
|
// NOTE: Can not use unique_ptr here because the default
|
|
|
|
|
|
|
|
// deleter will call GRPC Server's base class's dtor and
|
|
|
|
|
|
|
|
// that will cause a wired crash.
|
|
|
|
rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
|
|
|
|
rpc_service_ = new detail::AsyncGRPCServer(endpoint, true);
|
|
|
|
framework::ProgramDesc empty_program;
|
|
|
|
framework::ProgramDesc empty_program;
|
|
|
|
framework::Executor executor(dev_ctx.GetPlace());
|
|
|
|
framework::Executor executor(dev_ctx.GetPlace());
|
|
|
|
|