|
|
|
@ -29,7 +29,8 @@ DEFINE_bool(communicator_independent_recv_thread, true,
|
|
|
|
|
"use an independent to recv vars from parameter server");
|
|
|
|
|
DEFINE_int32(communicator_send_queue_size, 20,
|
|
|
|
|
"queue size to recv gradient before send");
|
|
|
|
|
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv");
|
|
|
|
|
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
|
|
|
|
|
"max grad num to send before recv parameters");
|
|
|
|
|
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
|
|
|
|
|
DEFINE_int32(communicator_max_merge_var_num, 20,
|
|
|
|
|
"max var num to merge and send");
|
|
|
|
@ -60,7 +61,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
<< FLAGS_communicator_independent_recv_thread;
|
|
|
|
|
VLOG(0) << "communicator_send_queue_size: "
|
|
|
|
|
<< FLAGS_communicator_send_queue_size;
|
|
|
|
|
VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms;
|
|
|
|
|
VLOG(0) << "communicator_max_send_grad_num_before_recv: "
|
|
|
|
|
<< FLAGS_communicator_max_send_grad_num_before_recv;
|
|
|
|
|
VLOG(0) << "communicator_thread_pool_size: "
|
|
|
|
|
<< FLAGS_communicator_thread_pool_size;
|
|
|
|
|
VLOG(0) << "communicator_max_merge_var_num: "
|
|
|
|
@ -102,6 +104,10 @@ void Communicator::SendThread() {
|
|
|
|
|
while (var_queue->Size() > 0 &&
|
|
|
|
|
merged_var_num < FLAGS_communicator_max_merge_var_num) {
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
// only count the send number of the first var
|
|
|
|
|
if (var_name == send_varname_to_queue_.begin()->first) {
|
|
|
|
|
grad_num_.fetch_add(1, std::memory_order_relaxed);
|
|
|
|
|
}
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
auto before_merge = GetCurrentUS();
|
|
|
|
@ -129,7 +135,7 @@ void Communicator::SendThread() {
|
|
|
|
|
}
|
|
|
|
|
auto after_run_send_graph = GetCurrentUS();
|
|
|
|
|
auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
|
|
|
|
|
if (send_graph_use_time > 10) {
|
|
|
|
|
if (send_graph_use_time > 100) {
|
|
|
|
|
VLOG(1) << "run send graph use time "
|
|
|
|
|
<< after_run_send_graph - before_run_send_graph;
|
|
|
|
|
}
|
|
|
|
@ -165,9 +171,14 @@ void Communicator::RecvAll() {
|
|
|
|
|
void Communicator::RecvThread() {
|
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
RecvAll();
|
|
|
|
|
std::this_thread::sleep_for(
|
|
|
|
|
std::chrono::milliseconds(FLAGS_communicator_recv_wait_ms));
|
|
|
|
|
auto grad_num = grad_num_.load();
|
|
|
|
|
if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) {
|
|
|
|
|
VLOG(1) << "current grad num " << grad_num;
|
|
|
|
|
RecvAll();
|
|
|
|
|
grad_num_.store(0);
|
|
|
|
|
} else {
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|