|
|
|
@ -52,14 +52,16 @@ inline double GetCurrentUS() {
|
|
|
|
|
return 1e+6 * time.tv_sec + time.tv_usec;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Out-of-class definition of the static std::once_flag member of Communicator.
// NOTE(review): presumably paired with std::call_once for one-time setup of
// the shared instance -- confirm at the call site, which is not visible here.
std::once_flag Communicator::init_flag_;
|
|
|
|
|
// Out-of-class definition of the static shared_ptr member holding the
// Communicator instance; it starts out null.
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
|
|
|
|
|
|
|
|
|
|
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
const RpcCtxMap &recv_varname_to_ctx,
|
|
|
|
|
Scope *recv_scope)
|
|
|
|
|
: send_varname_to_ctx_(send_varname_to_ctx),
|
|
|
|
|
recv_varname_to_ctx_(recv_varname_to_ctx),
|
|
|
|
|
recv_scope_(recv_scope) {
|
|
|
|
|
Scope *recv_scope) {
|
|
|
|
|
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
|
|
|
|
|
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
|
|
|
|
|
recv_scope_ = std::move(recv_scope);
|
|
|
|
|
|
|
|
|
|
// get all send information from graph, build vars_to_send
|
|
|
|
|
VLOG(0) << "communicator_independent_recv_thread: "
|
|
|
|
|
<< FLAGS_communicator_independent_recv_thread;
|
|
|
|
@ -98,7 +100,51 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Communicator::~Communicator() {
|
|
|
|
|
/**
 * Builds the send/recv RPC context maps from the ops found in block 0 of
 * `program`, then forwards them to the map-based InitImpl overload.
 *
 * For every op of type "send" it reads the first name of input "X" and the
 * attributes "send_varnames", "epmap", "sections" and "trainer_id".
 * For every op of type "recv" it enforces that attribute "do_not_run" is
 * greater than 0, then reads the first name of output "Out" and the
 * attributes "recv_varnames", "epmap" and "trainer_id"; the sections argument
 * is passed as an empty list. Ops of any other type are ignored.
 *
 * A warning is logged when neither a send nor a recv op was found; the (empty)
 * maps are still forwarded in that case.
 *
 * @param program      program description whose block 0 is scanned.
 * @param param_scope  scope handed through unchanged to the map-based overload.
 */
void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
                                 Scope *param_scope) {
  using RpcCtxMap = operators::distributed::RpcCtxMap;
  using RpcContext = operators::distributed::RpcContext;
  using StringList = std::vector<std::string>;

  VLOG(3) << "ProcessGraph";

  RpcCtxMap send_ctx_by_name;
  RpcCtxMap recv_ctx_by_name;

  for (auto *op : program.Block(0).AllOps()) {
    const auto &op_type = op->Type();
    VLOG(3) << "node name " << op_type;

    if (op_type == "send") {
      auto var_name = op->Input("X")[0];
      auto split_names =
          boost::get<StringList>(op->GetNullableAttr("send_varnames"));
      auto endpoints = boost::get<StringList>(op->GetNullableAttr("epmap"));
      auto sections =
          boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
      auto trainer = boost::get<int>(op->GetNullableAttr("trainer_id"));

      // Look the slot up once and reuse it for both the store and the log.
      auto &slot = send_ctx_by_name[var_name];
      slot = RpcContext(var_name, split_names, endpoints, sections, trainer);
      VLOG(3) << "find and init an send op: " << slot;
      continue;
    }

    if (op_type != "recv") {
      continue;
    }

    auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
    PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");

    auto var_name = op->Output("Out")[0];
    auto split_names =
        boost::get<StringList>(op->GetNullableAttr("recv_varnames"));
    auto endpoints = boost::get<StringList>(op->GetNullableAttr("epmap"));
    auto trainer = boost::get<int>(op->GetNullableAttr("trainer_id"));

    // recv contexts carry no height sections, hence the empty list.
    recv_ctx_by_name[var_name] =
        RpcContext(var_name, split_names, endpoints, {}, trainer);
  }

  // init communicator here
  if (send_ctx_by_name.empty() && recv_ctx_by_name.empty()) {
    LOG(WARNING) << "no var need to send and recv!!";
  }

  // The call is explicitly qualified, so it names AsyncCommunicator's
  // map-based overload directly.
  operators::distributed::AsyncCommunicator::InitImpl(
      send_ctx_by_name, recv_ctx_by_name, param_scope);
}
|
|
|
|
|
|
|
|
|
|
AsyncCommunicator::~AsyncCommunicator() {
|
|
|
|
|
if (FLAGS_v >= 3) {
|
|
|
|
|
std::string msg("~Communicator");
|
|
|
|
|
fwrite(msg.c_str(), msg.length(), 1, stdout);
|
|
|
|
@ -112,7 +158,7 @@ Communicator::~Communicator() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::SendThread() {
|
|
|
|
|
void AsyncCommunicator::SendThread() {
|
|
|
|
|
VLOG(3) << "SendThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
@ -175,50 +221,12 @@ void Communicator::SendThread() {
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "run send graph use time "
|
|
|
|
|
<< after_run_send_graph - before_run_send_graph;
|
|
|
|
|
RecvNonIndependent();
|
|
|
|
|
Recv();
|
|
|
|
|
}
|
|
|
|
|
VLOG(0) << "communicator stopped, send thread exit";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvNonIndependent() {
|
|
|
|
|
if (FLAGS_communicator_independent_recv_thread) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto grad_num = grad_num_.load();
|
|
|
|
|
if (grad_num > 0) {
|
|
|
|
|
RecvAll();
|
|
|
|
|
grad_num_.store(0);
|
|
|
|
|
} else {
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvAll() {
|
|
|
|
|
VLOG(3) << "parallel run recv graph";
|
|
|
|
|
if (!running_) return;
|
|
|
|
|
auto before_send = GetCurrentUS();
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
|
task_futures.reserve(recv_varname_to_ctx_.size());
|
|
|
|
|
for (auto &iter : recv_varname_to_ctx_) {
|
|
|
|
|
auto recv_task = [this, &iter] {
|
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
|
VLOG(4) << "recv var " << var_name;
|
|
|
|
|
auto recv_functor = distributed::ParameterRecv<float>();
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
|
|
|
|
|
}
|
|
|
|
|
for (auto &task : task_futures) {
|
|
|
|
|
task.wait();
|
|
|
|
|
}
|
|
|
|
|
auto after_recv = GetCurrentUS();
|
|
|
|
|
VLOG(1) << "run recv graph use time " << after_recv - before_send;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvThread() {
|
|
|
|
|
void AsyncCommunicator::RecvThread() {
|
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
auto grad_num = grad_num_.load();
|
|
|
|
@ -233,7 +241,7 @@ void Communicator::RecvThread() {
|
|
|
|
|
VLOG(0) << "communicator stopped, recv thread exit";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::Send(const std::string &var_name,
|
|
|
|
|
void AsyncCommunicator::Send(const std::string &var_name,
|
|
|
|
|
const framework::Scope &scope) {
|
|
|
|
|
VLOG(3) << "communicator send " << var_name;
|
|
|
|
|
// push var into send queue by var_name
|
|
|
|
@ -255,56 +263,45 @@ void Communicator::Send(const std::string &var_name,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::Init(const paddle::framework::ProgramDesc &program,
|
|
|
|
|
Scope *param_scope) {
|
|
|
|
|
using RpcCtxMap = operators::distributed::RpcCtxMap;
|
|
|
|
|
VLOG(3) << "ProcessGraph";
|
|
|
|
|
RpcCtxMap send_varname_to_ctx;
|
|
|
|
|
RpcCtxMap recv_varname_to_ctx;
|
|
|
|
|
for (auto *op : program.Block(0).AllOps()) {
|
|
|
|
|
VLOG(3) << "node name " << op->Type();
|
|
|
|
|
if (op->Type() == "send") {
|
|
|
|
|
auto send_var_name = op->Input("X")[0];
|
|
|
|
|
auto send_varnames = boost::get<std::vector<std::string>>(
|
|
|
|
|
op->GetNullableAttr("send_varnames"));
|
|
|
|
|
auto epmap =
|
|
|
|
|
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
|
|
|
|
|
auto height_section =
|
|
|
|
|
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
|
|
|
|
|
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
|
|
|
|
|
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
|
|
|
|
|
send_var_name, send_varnames, epmap, height_section, trainer_id);
|
|
|
|
|
VLOG(3) << "find and init an send op: "
|
|
|
|
|
<< send_varname_to_ctx[send_var_name];
|
|
|
|
|
} else if (op->Type() == "recv") {
|
|
|
|
|
auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
|
|
|
|
|
PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");
|
|
|
|
|
auto recv_var_name = op->Output("Out")[0];
|
|
|
|
|
auto recv_varnames = boost::get<std::vector<std::string>>(
|
|
|
|
|
op->GetNullableAttr("recv_varnames"));
|
|
|
|
|
auto epmap =
|
|
|
|
|
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
|
|
|
|
|
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
|
|
|
|
|
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
|
|
|
|
|
recv_var_name, recv_varnames, epmap, {}, trainer_id);
|
|
|
|
|
}
|
|
|
|
|
void AsyncCommunicator::Recv() {
|
|
|
|
|
if (FLAGS_communicator_independent_recv_thread) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// init communicator here
|
|
|
|
|
if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) {
|
|
|
|
|
LOG(WARNING) << "no var need to send and recv!!";
|
|
|
|
|
auto grad_num = grad_num_.load();
|
|
|
|
|
if (grad_num > 0) {
|
|
|
|
|
RecvAll();
|
|
|
|
|
grad_num_.store(0);
|
|
|
|
|
} else {
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
|
|
}
|
|
|
|
|
operators::distributed::Communicator::Init(send_varname_to_ctx,
|
|
|
|
|
recv_varname_to_ctx, param_scope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the raw pointer currently held by the static communicator_
// shared_ptr; ownership is not transferred to the caller.
Communicator *Communicator::GetInstance() {
  Communicator *const instance = communicator_.get();
  return instance;
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Communicator> Communicator::GetInstantcePtr() {
|
|
|
|
|
return communicator_;
|
|
|
|
|
void AsyncCommunicator::RecvAll() {
|
|
|
|
|
VLOG(3) << "parallel run recv graph";
|
|
|
|
|
if (!running_) return;
|
|
|
|
|
auto before_send = GetCurrentUS();
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
|
task_futures.reserve(recv_varname_to_ctx_.size());
|
|
|
|
|
for (auto &iter : recv_varname_to_ctx_) {
|
|
|
|
|
auto recv_task = [this, &iter] {
|
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
|
VLOG(4) << "recv var " << var_name;
|
|
|
|
|
auto recv_functor = distributed::ParameterRecv<float>();
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
|
|
|
|
|
}
|
|
|
|
|
for (auto &task : task_futures) {
|
|
|
|
|
task.wait();
|
|
|
|
|
}
|
|
|
|
|
auto after_recv = GetCurrentUS();
|
|
|
|
|
VLOG(1) << "run recv graph use time " << after_recv - before_send;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::Start() {
|
|
|
|
|
void AsyncCommunicator::Start() {
|
|
|
|
|
VLOG(0) << "Communicator start";
|
|
|
|
|
if (!communicator_) {
|
|
|
|
|
VLOG(0) << "Communicator is not inited, do nothing";
|
|
|
|
@ -313,15 +310,15 @@ void Communicator::Start() {
|
|
|
|
|
running_ = true;
|
|
|
|
|
// start send and recv thread
|
|
|
|
|
send_thread_.reset(
|
|
|
|
|
new std::thread(std::bind(&Communicator::SendThread, this)));
|
|
|
|
|
new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
|
|
|
|
|
if (FLAGS_communicator_independent_recv_thread) {
|
|
|
|
|
recv_thread_.reset(
|
|
|
|
|
new std::thread(std::bind(&Communicator::RecvThread, this)));
|
|
|
|
|
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::Stop() {
|
|
|
|
|
void AsyncCommunicator::Stop() {
|
|
|
|
|
VLOG(0) << "Communicator stop";
|
|
|
|
|
running_ = false;
|
|
|
|
|
if (!communicator_) {
|
|
|
|
|