fix blocking problem

mixed_precision_init
Qiao Longfei 6 years ago
parent c0e5941e31
commit 63cd70a8b8
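
What the diff below does, as far as it can be read from the hunks: the old SendThread called BlockingQueue::NotEmpty(), which waits on a condition variable until the queue holds an element, so the per-variable loop could block indefinitely on the first empty gradient queue. The patch replaces that wait with a non-blocking Size() > 0 check, logs and skips empty queues, and runs a new RecvAll() pass at the end of each send iteration. RecvThread() is kept but Communicator::Start() no longer launches it (the reset call is commented out). BlockingQueue itself is reduced to a single condition variable, the blocking NotEmpty() helper is removed, and Pop() now notifies waiters. SendOp hands its input to the Communicator singleton instead of doing a synchronous ParameterSend, and VLOG(3) tracing is added along the send/recv path.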

@@ -75,10 +75,11 @@ void Communicator::SendThread() {
   while (running_) {
     std::vector<std::future<void>> task_futures;
     task_futures.reserve(send_varname_to_ctx_.size());
+    VLOG(3) << "run send graph";
     for (auto &iter : send_varname_to_queue_) {
       auto &var_name = iter.first;
       auto &var_queue = iter.second;
-      if (var_queue->NotEmpty()) {  // will block if queue is empty
+      if (var_queue->Size() > 0) {
         auto send_task = [this, &var_name, &var_queue] {
           VLOG(3) << "merge var " << var_name << " and send";
           std::vector<std::shared_ptr<Variable>> vars;
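
For context on the change above: NotEmpty() (removed later in this diff) waits on a condition variable until the queue is non-empty, so a single idle variable stalls the whole send loop, while Size() only takes the lock and returns. A minimal standalone sketch of the two behaviors; DemoQueue and its members are illustrative names, not the Paddle classes:

#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class DemoQueue {
 public:
  // Old style: block the caller until an element arrives. If nothing is ever
  // pushed, the caller (here, the send loop) never gets past this call.
  bool NotEmpty() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    return true;
  }

  // New style: report the current size under the lock and return immediately,
  // letting the caller skip empty queues and move on to the next variable.
  size_t Size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();
  }

 private:
  std::deque<T> queue_;
  std::mutex mutex_;
  std::condition_variable cv_;
};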
@@ -96,18 +97,20 @@ void Communicator::SendThread() {
         };
         task_futures.emplace_back(
             send_threadpool_->enqueue(std::move(send_task)));
+      } else {
+        VLOG(3) << var_name << " queue empty";
       }
     }
     for (auto &task_f : task_futures) {
       task_f.wait();
     }
+    VLOG(3) << "run send graph done";
+    RecvAll();
   }
 }
 
-void Communicator::RecvThread() {
-  VLOG(3) << "RecvThread start!";
-  while (running_) {
-    // parallel run recv graph
+void Communicator::RecvAll() {
+  VLOG(3) << "parallel run recv graph";
   std::vector<std::future<void>> task_futures;
   task_futures.reserve(recv_varname_to_ctx_.size());
   for (auto &iter : recv_varname_to_ctx_) {
@@ -117,12 +120,18 @@ void Communicator::RecvThread() {
       auto recv_functor = distributed::ParameterRecv<float>();
       recv_functor(iter.second, *recv_scope_);
     };
-    task_futures.emplace_back(
-        recv_threadpool_->enqueue(std::move(recv_task)));
+    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
   }
   for (auto &task : task_futures) {
     task.wait();
   }
+  VLOG(3) << "run recv graph done";
+}
+
+void Communicator::RecvThread() {
+  VLOG(3) << "RecvThread start!";
+  while (running_) {
+    RecvAll();
     // TODO(qiao) need to be configuable
     std::this_thread::sleep_for(std::chrono::milliseconds(200));
   }
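
The reshaped loop follows a plain fan-out/join pattern: enqueue one task per non-empty queue, wait on all futures, then run the recv pass before the next iteration. A rough sketch of that pattern with std::async standing in for the framework's thread pool; the variable names and the fake payload are made up for illustration:

#include <future>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Pretend each entry is a gradient queue; the int is how many grads are pending.
  std::map<std::string, int> pending = {{"w0", 3}, {"w1", 0}, {"w2", 5}};

  for (int step = 0; step < 2; ++step) {
    std::vector<std::future<void>> futures;
    for (auto &kv : pending) {
      if (kv.second > 0) {   // non-blocking check, like var_queue->Size() > 0
        auto *entry = &kv;   // map elements stay put, so the pointer stays valid
        futures.emplace_back(std::async(std::launch::async, [entry] {
          std::cout << "merge and send " + entry->first + "\n";  // stand-in for the send task
          entry->second = 0;
        }));
      }
    }
    for (auto &f : futures) f.wait();      // join before touching parameters
    std::cout << "recv all parameters\n";  // stand-in for RecvAll()
  }
  return 0;
}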
@@ -136,7 +145,9 @@ void Communicator::Send(const std::string &var_name,
   PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
   auto tmp_grad_var = std::make_shared<Variable>();
   framework::CopyVariable(*grad_var, tmp_grad_var.get());
-  send_varname_to_queue_[var_name]->Push(tmp_grad_var);
+  auto &queue = send_varname_to_queue_.at(var_name);
+  VLOG(3) << "send " << var_name << " queue size " << queue->Size();
+  queue->Push(tmp_grad_var);
 }
 
 Communicator *Communicator::GetInstance() { return communicator_.get(); }
@@ -146,8 +157,8 @@ void Communicator::Start() {
   // start send and recv thread
   send_thread_.reset(
       new std::thread(std::bind(&Communicator::SendThread, this)));
-  recv_thread_.reset(
-      new std::thread(std::bind(&Communicator::RecvThread, this)));
+  // recv_thread_.reset(
+  //     new std::thread(std::bind(&Communicator::RecvThread, this)));
 }
 
 }  // namespace distributed

@@ -43,37 +43,36 @@ class BlockingQueue {
   }
 
   bool Push(const T& elem) {
+    {
       std::unique_lock<std::mutex> lock(mutex_);
-      send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
+      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
       PADDLE_ENFORCE_LT(queue_.size(), capacity_);
       queue_.push_back(elem);
-      recv_cv_.notify_one();
+    }
+    cv_.notify_one();
     return true;
   }
 
   bool Push(T&& elem) {
+    {
       std::unique_lock<std::mutex> lock(mutex_);
-      send_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
+      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
       PADDLE_ENFORCE_LT(queue_.size(), capacity_);
       queue_.emplace_back(std::move(elem));
-      recv_cv_.notify_one();
+    }
+    cv_.notify_one();
     return true;
   }
 
   T Pop() {
     std::unique_lock<std::mutex> lock(mutex_);
-    recv_cv_.wait(lock, [=] { return !queue_.empty(); });
+    cv_.wait(lock, [=] { return !queue_.empty(); });
     T rc(std::move(queue_.front()));
     queue_.pop_front();
+    cv_.notify_one();
     return rc;
   }
 
-  bool NotEmpty() {
-    std::unique_lock<std::mutex> lock(mutex_);
-    recv_cv_.wait(lock, [=] { return !queue_.empty(); });
-    return true;
-  }
-
   size_t Cap() const {
     std::lock_guard<std::mutex> lock(mutex_);
     return capacity_;
@@ -89,8 +88,7 @@ class BlockingQueue {
   std::deque<T> queue_;
   mutable std::mutex mutex_;
-  std::condition_variable recv_cv_;
-  std::condition_variable send_cv_;
+  std::condition_variable cv_;
 };
 
 using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
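
After these hunks the queue serves producers and consumers with one condition variable instead of the old send_cv_/recv_cv_ pair, which is why Pop() above gained a notify_one(): with a shared cv, every state change has to wake someone, otherwise a waiting Push() may never learn that space freed up. A self-contained sketch of the resulting shape (simplified: no PADDLE_ENFORCE, and the constructor here is invented for completeness):

#include <condition_variable>
#include <deque>
#include <mutex>

template <typename T>
class SimpleBlockingQueue {
 public:
  explicit SimpleBlockingQueue(size_t capacity) : capacity_(capacity) {}

  bool Push(T elem) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
      queue_.emplace_back(std::move(elem));
    }                   // release the lock before notifying
    cv_.notify_one();
    return true;
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return !queue_.empty(); });
    T rc(std::move(queue_.front()));
    queue_.pop_front();
    cv_.notify_one();   // one cv serves both directions, so wake someone here too
    return rc;
  }

  size_t Size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  const size_t capacity_;
  std::deque<T> queue_;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
};

The two-cv version can target the correct side of the queue with each notify; the single-cv version is simpler but relies on notifying after every push and pop, and notify_all would be the more conservative choice if many threads wait on both ends.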
@@ -127,6 +125,8 @@ class Communicator {
   void Send(const std::string& var_name, const framework::Scope& scope);
 
  private:
+  // recv all parameter
+  void RecvAll();
   void SendThread();
   void RecvThread();

@@ -41,6 +41,7 @@ using DDim = framework::DDim;
 template <typename T>
 void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
                                   const framework::Scope &scope) {
+  VLOG(3) << "ParameterRecv in";
   framework::Scope *local_scope = scope.NewTmpScope();
 
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -90,6 +91,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
   }
 
   delete local_scope;
+  VLOG(3) << "ParameterRecv out";
 }
 
 template struct ParameterRecv<float>;

@@ -48,12 +48,15 @@ class SendOp : public framework::OperatorBase {
     if (send_varnames.size() > 0) {
       PADDLE_ENFORCE_EQ(ins.size(), 1, "");
-      // auto send_functor = distributed::ParameterSend<float>();
-      // auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames,
-      //                                        epmap,
-      //                                        height_sections);
-      // send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
+      /*
+      auto send_functor = distributed::ParameterSend<float>();
+      auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
+                                             height_sections);
+      send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
+      */
+      VLOG(3) << "send " << ins[0];
       distributed::Communicator::GetInstance()->Send(ins[0], scope);
+      VLOG(3) << "send " << ins[0] << " done";
     } else {
       platform::DeviceContextPool& pool =
           platform::DeviceContextPool::Instance();
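
Reading of the send_op hunk: the operator no longer performs the RPC itself; the synchronous ParameterSend path is kept only as a commented-out block, and the op just hands the input variable to the Communicator singleton, whose SendThread later merges and ships it. The two VLOG(3) lines bracket that hand-off so the asynchronous path can be traced in the logs.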
