|
|
@ -75,10 +75,11 @@ void Communicator::SendThread() {
|
|
|
|
while (running_) {
|
|
|
|
while (running_) {
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
task_futures.reserve(send_varname_to_ctx_.size());
|
|
|
|
task_futures.reserve(send_varname_to_ctx_.size());
|
|
|
|
|
|
|
|
VLOG(3) << "run send graph";
|
|
|
|
for (auto &iter : send_varname_to_queue_) {
|
|
|
|
for (auto &iter : send_varname_to_queue_) {
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
auto &var_queue = iter.second;
|
|
|
|
auto &var_queue = iter.second;
|
|
|
|
if (var_queue->NotEmpty()) { // will block if queue is empty
|
|
|
|
if (var_queue->Size() > 0) {
|
|
|
|
auto send_task = [this, &var_name, &var_queue] {
|
|
|
|
auto send_task = [this, &var_name, &var_queue] {
|
|
|
|
VLOG(3) << "merge var " << var_name << " and send";
|
|
|
|
VLOG(3) << "merge var " << var_name << " and send";
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
@ -96,18 +97,20 @@ void Communicator::SendThread() {
|
|
|
|
};
|
|
|
|
};
|
|
|
|
task_futures.emplace_back(
|
|
|
|
task_futures.emplace_back(
|
|
|
|
send_threadpool_->enqueue(std::move(send_task)));
|
|
|
|
send_threadpool_->enqueue(std::move(send_task)));
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
VLOG(3) << var_name << " queue empty";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &task_f : task_futures) {
|
|
|
|
for (auto &task_f : task_futures) {
|
|
|
|
task_f.wait();
|
|
|
|
task_f.wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(3) << "run send graph done";
|
|
|
|
|
|
|
|
RecvAll();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvThread() {
|
|
|
|
void Communicator::RecvAll() {
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
VLOG(3) << "parallel run recv graph";
|
|
|
|
while (running_) {
|
|
|
|
|
|
|
|
// parallel run recv graph
|
|
|
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
task_futures.reserve(recv_varname_to_ctx_.size());
|
|
|
|
task_futures.reserve(recv_varname_to_ctx_.size());
|
|
|
|
for (auto &iter : recv_varname_to_ctx_) {
|
|
|
|
for (auto &iter : recv_varname_to_ctx_) {
|
|
|
@ -117,12 +120,18 @@ void Communicator::RecvThread() {
|
|
|
|
auto recv_functor = distributed::ParameterRecv<float>();
|
|
|
|
auto recv_functor = distributed::ParameterRecv<float>();
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
};
|
|
|
|
};
|
|
|
|
task_futures.emplace_back(
|
|
|
|
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
|
|
|
|
recv_threadpool_->enqueue(std::move(recv_task)));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &task : task_futures) {
|
|
|
|
for (auto &task : task_futures) {
|
|
|
|
task.wait();
|
|
|
|
task.wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
VLOG(3) << "run recv graph done";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvThread() {
|
|
|
|
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
|
|
|
|
while (running_) {
|
|
|
|
|
|
|
|
RecvAll();
|
|
|
|
// TODO(qiao) need to be configuable
|
|
|
|
// TODO(qiao) need to be configuable
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -136,7 +145,9 @@ void Communicator::Send(const std::string &var_name,
|
|
|
|
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
|
|
|
|
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
|
|
|
|
auto tmp_grad_var = std::make_shared<Variable>();
|
|
|
|
auto tmp_grad_var = std::make_shared<Variable>();
|
|
|
|
framework::CopyVariable(*grad_var, tmp_grad_var.get());
|
|
|
|
framework::CopyVariable(*grad_var, tmp_grad_var.get());
|
|
|
|
send_varname_to_queue_[var_name]->Push(tmp_grad_var);
|
|
|
|
auto &queue = send_varname_to_queue_.at(var_name);
|
|
|
|
|
|
|
|
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
|
|
|
|
|
|
|
|
queue->Push(tmp_grad_var);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Communicator *Communicator::GetInstance() { return communicator_.get(); }
|
|
|
|
Communicator *Communicator::GetInstance() { return communicator_.get(); }
|
|
|
@ -146,8 +157,8 @@ void Communicator::Start() {
|
|
|
|
// start send and recv thread
|
|
|
|
// start send and recv thread
|
|
|
|
send_thread_.reset(
|
|
|
|
send_thread_.reset(
|
|
|
|
new std::thread(std::bind(&Communicator::SendThread, this)));
|
|
|
|
new std::thread(std::bind(&Communicator::SendThread, this)));
|
|
|
|
recv_thread_.reset(
|
|
|
|
// recv_thread_.reset(
|
|
|
|
new std::thread(std::bind(&Communicator::RecvThread, this)));
|
|
|
|
// new std::thread(std::bind(&Communicator::RecvThread, this)));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace distributed
|
|
|
|
} // namespace distributed
|
|
|
|