gen nccl id use socket (#29431)
parent
d72604cd46
commit
467c716963
@ -0,0 +1,201 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/framework/op_proto_maker.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/var_type_traits.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/string/split.h"
|
||||
|
||||
#include "paddle/fluid/operators/collective/gen_nccl_id_op_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class GenNCCLIdOp : public framework::OperatorBase {
|
||||
public:
|
||||
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
|
||||
const framework::VariableNameMap& outputs,
|
||||
const framework::AttributeMap& attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
void RunImpl(const framework::Scope& scope,
|
||||
const platform::Place& dev_place) const override {
|
||||
std::vector<std::string> trainers =
|
||||
Attr<std::vector<std::string>>("trainers");
|
||||
int trainer_id = Attr<int>("trainer_id");
|
||||
std::string endpoint = trainers[trainer_id];
|
||||
|
||||
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
|
||||
"trainer_id %d is less than 0. Its "
|
||||
"valid range is [0, trainer_size)"));
|
||||
PADDLE_ENFORCE_LT(
|
||||
trainer_id, static_cast<int>(trainers.size()),
|
||||
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
|
||||
"range is [0, trainer_size)",
|
||||
trainer_id));
|
||||
|
||||
int nccl_comm_num = Attr<int>("nccl_comm_num");
|
||||
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
|
||||
int inter_nranks = Attr<int>("hierarchical_allreduce_inter_nranks");
|
||||
int inter_trainer_id = -1;
|
||||
int exter_trainer_id = -1;
|
||||
|
||||
if (use_hierarchical_allreduce) {
|
||||
PADDLE_ENFORCE_GT(
|
||||
trainers.size(), 1,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"The number of collective trainers %llu <= 1", trainers.size()));
|
||||
PADDLE_ENFORCE_GT(
|
||||
inter_nranks, 1,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
|
||||
inter_nranks));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
trainers.size() % inter_nranks, 0,
|
||||
platform::errors::PreconditionNotMet(
|
||||
"The number of trainers %llu mod inter_nranks %d is not equal 0",
|
||||
trainers.size(), inter_nranks));
|
||||
|
||||
inter_trainer_id = trainer_id % inter_nranks;
|
||||
|
||||
if (trainer_id % inter_nranks == 0) {
|
||||
exter_trainer_id = trainer_id / inter_nranks;
|
||||
}
|
||||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
for (size_t i = 0; i < trainers.size(); i++) {
|
||||
ss << trainers[i] << ",";
|
||||
}
|
||||
|
||||
VLOG(1) << "trainer_id:" << trainer_id
|
||||
<< ", use_hierarchical_allreduce:" << use_hierarchical_allreduce
|
||||
<< ", nccl_comm_num:" << nccl_comm_num
|
||||
<< ", inter_nranks:" << inter_nranks
|
||||
<< ", inter_trainer_id:" << inter_trainer_id
|
||||
<< ", exter_trainer_id:" << exter_trainer_id
|
||||
<< ", trainers:" << ss.str();
|
||||
|
||||
int server_fd = -1;
|
||||
|
||||
/// 1. init flat
|
||||
std::function<std::string(size_t)> func = platform::GetFlatNCCLVarName;
|
||||
if (trainer_id == 0) {
|
||||
// server endpoints
|
||||
std::vector<std::string> flat_endpoints;
|
||||
flat_endpoints.insert(flat_endpoints.begin(), trainers.begin() + 1,
|
||||
trainers.end());
|
||||
SendBroadCastNCCLID(flat_endpoints, nccl_comm_num, func, scope);
|
||||
} else {
|
||||
server_fd = CreateListenSocket(endpoint);
|
||||
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
|
||||
}
|
||||
|
||||
/// 2. hierarchical inter ncclid
|
||||
func = platform::GetHierarchicalInterNCCLVarName;
|
||||
if (inter_trainer_id == 0) {
|
||||
std::ostringstream ss;
|
||||
ss << endpoint;
|
||||
std::vector<std::string> inter_endpoints;
|
||||
for (int i = trainer_id + 1; i < trainer_id + inter_nranks &&
|
||||
i < static_cast<int>(trainers.size());
|
||||
i++) {
|
||||
ss << ",";
|
||||
inter_endpoints.push_back(trainers[i]);
|
||||
ss << trainers[i];
|
||||
}
|
||||
VLOG(1) << "Hierarchical inter ring endpoints:" << ss.str();
|
||||
|
||||
SendBroadCastNCCLID(inter_endpoints, nccl_comm_num, func, scope);
|
||||
} else if (inter_trainer_id > 0) {
|
||||
VLOG(1) << "Hierarchical inter ring";
|
||||
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
|
||||
}
|
||||
|
||||
/// 3. hierarchical exter ncclid
|
||||
func = platform::GetHierarchicalExterNCCLVarName;
|
||||
if (exter_trainer_id == 0) {
|
||||
std::ostringstream ss;
|
||||
std::vector<std::string> exter_endpoints;
|
||||
ss << endpoint;
|
||||
for (size_t i = inter_nranks; i < trainers.size(); i += inter_nranks) {
|
||||
ss << ",";
|
||||
exter_endpoints.push_back(trainers[i]);
|
||||
ss << trainers[i];
|
||||
}
|
||||
VLOG(1) << "Hierarchical exter ring endpoints:" << ss.str();
|
||||
|
||||
SendBroadCastNCCLID(exter_endpoints, nccl_comm_num, func, scope);
|
||||
} else if (exter_trainer_id > 0) {
|
||||
VLOG(1) << "Hierarchical exter ring";
|
||||
RecvBroadCastNCCLID(server_fd, endpoint, nccl_comm_num, func, scope);
|
||||
}
|
||||
|
||||
// close socket server
|
||||
if (trainer_id != 0) {
|
||||
CloseSocket(server_fd);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
|
||||
AddComment(R"DOC(
|
||||
GenNCCLId operator
|
||||
|
||||
For trainer 0: generate a new UniqueId and send it to all the other trainers.
|
||||
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
|
||||
)DOC");
|
||||
AddAttr<std::vector<std::string>>(
|
||||
"trainers",
|
||||
"['trainer0_ip:port', 'trainer1_ip:port', ...] "
|
||||
"list of all trainer endpoints")
|
||||
.SetDefault({});
|
||||
AddAttr<int>("trainer_id",
|
||||
"(int) "
|
||||
"The index of the trainer in distributed training.");
|
||||
AddAttr<int>("nccl_comm_num",
|
||||
"(int default 1) "
|
||||
"The number of nccl communicator num.")
|
||||
.SetDefault(1);
|
||||
AddAttr<bool>("use_hierarchical_allreduce",
|
||||
"(bool default false) "
|
||||
"Wheter to use hierarchical allreduce.")
|
||||
.SetDefault(false);
|
||||
AddAttr<int>("hierarchical_allreduce_inter_nranks",
|
||||
"(int default 1) "
|
||||
"Wheter to use hierarchical allreduce.")
|
||||
.SetDefault(-1);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,48 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class Scope;
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
int CreateListenSocket(const std::string& ep);
|
||||
|
||||
void CloseSocket(int fd);
|
||||
|
||||
void SendBroadCastNCCLID(std::vector<std::string> servers, int nccl_comm_num,
|
||||
std::function<std::string(size_t)> func,
|
||||
const framework::Scope& scope);
|
||||
|
||||
// server listen on endpoint, then recv nccl id
|
||||
void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num,
|
||||
std::function<std::string(size_t)> func,
|
||||
const framework::Scope& scope);
|
||||
|
||||
// recv nccl id from socket
|
||||
void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num,
|
||||
std::function<std::string(size_t)> func,
|
||||
const framework::Scope& scope);
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import os
|
||||
from launch_function_helper import wait, _find_free_port
|
||||
from multiprocessing import Pool, Process
|
||||
|
||||
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10")
|
||||
|
||||
import paddle
|
||||
from paddle.fluid import core
|
||||
|
||||
paddle.enable_static()
|
||||
|
||||
|
||||
def run_gen_ncc_id(attr):
|
||||
nccl_comm_num = attr['nccl_comm_num']
|
||||
use_hallreduce = attr['use_hierarchical_allreduce']
|
||||
|
||||
startup_program = paddle.static.default_startup_program()
|
||||
main_program = paddle.static.default_main_program()
|
||||
|
||||
with paddle.static.program_guard(main_program, startup_program):
|
||||
nccl_id_var = startup_program.global_block().create_var(
|
||||
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
|
||||
|
||||
for i in range(1, nccl_comm_num):
|
||||
startup_program.global_block().create_var(
|
||||
name="NCCLID_{}".format(i),
|
||||
persistable=True,
|
||||
type=core.VarDesc.VarType.RAW)
|
||||
|
||||
if use_hallreduce:
|
||||
for i in range(0, nccl_comm_num):
|
||||
startup_program.global_block().create_var(
|
||||
name="Hierarchical_inter_NCCLID_{}".format(i),
|
||||
persistable=True,
|
||||
type=core.VarDesc.VarType.RAW)
|
||||
startup_program.global_block().create_var(
|
||||
name="Hierarchical_exter_NCCLID_{}".format(i),
|
||||
persistable=True,
|
||||
type=core.VarDesc.VarType.RAW)
|
||||
|
||||
startup_program.global_block().append_op(
|
||||
type="gen_nccl_id",
|
||||
inputs={},
|
||||
outputs={"NCCLID": nccl_id_var},
|
||||
attrs=attr)
|
||||
|
||||
place = paddle.CPUPlace()
|
||||
|
||||
exe = paddle.static.Executor(place)
|
||||
exe.run(startup_program)
|
||||
|
||||
|
||||
class TestGenNcclIdOp(unittest.TestCase):
|
||||
def setUp(self):
|
||||
try:
|
||||
self._dist_ut_port_0 = int(os.environ["PADDLE_DIST_UT_PORT"])
|
||||
except Exception as e:
|
||||
self._dist_ut_port_0 = _find_free_port(set())
|
||||
|
||||
def gen_nccl_id(self, nranks=2):
|
||||
nccl_comm_num = 1
|
||||
if nranks == 2:
|
||||
use_hallreduce = False
|
||||
hallreduce_inter_nranks = -1
|
||||
elif nranks == 4:
|
||||
use_hallreduce = True
|
||||
hallreduce_inter_nranks = 2
|
||||
|
||||
port = self._dist_ut_port_0
|
||||
trainers = []
|
||||
for i in range(nranks):
|
||||
trainers.append('127.0.0.1:{}'.format(port + i))
|
||||
|
||||
attr = {
|
||||
"trainers": trainers,
|
||||
"trainer_id": 0,
|
||||
"nccl_comm_num": nccl_comm_num,
|
||||
"use_hierarchical_allreduce": use_hallreduce,
|
||||
"hierarchical_allreduce_inter_nranks": hallreduce_inter_nranks,
|
||||
}
|
||||
|
||||
procs = []
|
||||
for i in range(nranks):
|
||||
attr['trainer_id'] = i
|
||||
p = Process(target=run_gen_ncc_id, args=(attr, ))
|
||||
p.start()
|
||||
procs.append(p)
|
||||
|
||||
wait(procs, timeout=120)
|
||||
|
||||
def test_flat(self):
|
||||
print(">>> test gen flat nccl id")
|
||||
self.gen_nccl_id(2)
|
||||
print("<<< end test gen flat nccl id")
|
||||
|
||||
def test_hierarchical(self):
|
||||
print(">>> test gen hierarchical nccl id")
|
||||
self.gen_nccl_id(4)
|
||||
print("<<< end test gen hierarchical nccl id")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue