|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/distributed/communicator.h"
|
|
|
|
|
|
|
|
|
|
#include <gflags/gflags.h>
|
|
|
|
|
#include <chrono> // NOLINT
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
|
|
|
|
@ -24,6 +25,13 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/distributed/parameter_send.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
|
|
|
|
|
DEFINE_bool(communicator_independent_recv_thread, true,
|
|
|
|
|
"use an independent to recv vars from parameter server");
|
|
|
|
|
DEFINE_int32(communicator_send_queue_size, 20,
|
|
|
|
|
"queue size to recv gradient before send");
|
|
|
|
|
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv");
|
|
|
|
|
DEFINE_int32(communicator_thread_pool_size, 5, "wait time between each recv");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace distributed {
|
|
|
|
@ -70,6 +78,38 @@ static inline void MergeVars(const std::string &var_name,
|
|
|
|
|
std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
|
|
|
|
|
std::once_flag Communicator::init_flag_;
|
|
|
|
|
|
|
|
|
|
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
const RpcCtxMap &recv_varname_to_ctx,
|
|
|
|
|
Scope *recv_scope)
|
|
|
|
|
: send_varname_to_ctx_(send_varname_to_ctx),
|
|
|
|
|
recv_varname_to_ctx_(recv_varname_to_ctx),
|
|
|
|
|
recv_scope_(recv_scope) {
|
|
|
|
|
// get all send information from graph, build vars_to_send
|
|
|
|
|
VLOG(0) << "communicator_independent_recv_thread: "
|
|
|
|
|
<< FLAGS_communicator_independent_recv_thread;
|
|
|
|
|
VLOG(0) << "communicator_send_queue_size: "
|
|
|
|
|
<< FLAGS_communicator_send_queue_size;
|
|
|
|
|
VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms;
|
|
|
|
|
VLOG(0) << "communicator_thread_pool_size: "
|
|
|
|
|
<< FLAGS_communicator_thread_pool_size;
|
|
|
|
|
send_scope_.reset(new Scope());
|
|
|
|
|
for (auto &iter : send_varname_to_ctx_) {
|
|
|
|
|
send_varname_to_queue_[iter.first] =
|
|
|
|
|
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
|
|
|
|
|
FLAGS_communicator_send_queue_size);
|
|
|
|
|
}
|
|
|
|
|
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
|
|
|
|
|
recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Communicator::~Communicator() {
|
|
|
|
|
VLOG(3) << "~Communicator";
|
|
|
|
|
running_ = false;
|
|
|
|
|
if (send_thread_) send_thread_->join();
|
|
|
|
|
if (recv_thread_) recv_thread_->join();
|
|
|
|
|
VLOG(3) << "~Communicator done";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::SendThread() {
|
|
|
|
|
VLOG(3) << "SendThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
@ -105,7 +145,9 @@ void Communicator::SendThread() {
|
|
|
|
|
task_f.wait();
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "run send graph done";
|
|
|
|
|
RecvAll();
|
|
|
|
|
if (!FLAGS_communicator_independent_recv_thread) {
|
|
|
|
|
RecvAll();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -132,8 +174,8 @@ void Communicator::RecvThread() {
|
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
RecvAll();
|
|
|
|
|
// TODO(qiao) need to be configuable
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
|
|
|
|
std::this_thread::sleep_for(
|
|
|
|
|
std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -157,8 +199,10 @@ void Communicator::Start() {
|
|
|
|
|
// start send and recv thread
|
|
|
|
|
send_thread_.reset(
|
|
|
|
|
new std::thread(std::bind(&Communicator::SendThread, this)));
|
|
|
|
|
// recv_thread_.reset(
|
|
|
|
|
// new std::thread(std::bind(&Communicator::RecvThread, this)));
|
|
|
|
|
if (FLAGS_communicator_independent_recv_thread) {
|
|
|
|
|
recv_thread_.reset(
|
|
|
|
|
new std::thread(std::bind(&Communicator::RecvThread, this)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace distributed
|
|
|
|
|