|
|
|
@ -32,6 +32,9 @@ DEFINE_int32(communicator_send_queue_size, 20,
|
|
|
|
|
// Command-line flag (int32, default 20): per its help text, the maximum number
// of gradients to send before parameters are received.
// NOTE(review): the visible SendThread code only shows grad_num_ being
// incremented for the first entry of send_varname_to_queue_; the comparison
// against this flag is not visible here -- confirm at the recv side.
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
|
|
|
|
|
"max grad num to send before recv parameters");
|
|
|
|
|
// Command-line flag (int32, default 5): per its help text, the number of
// threads used for send or recv work.
// NOTE(review): the pool construction that reads this flag is not visible in
// this fragment -- confirm whether send and recv share one pool.
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
|
|
|
|
|
// Command-line flag (int32, default 5): how many consecutive times the send
// thread polls an empty variable queue before it stops waiting and merges
// whatever it has collected so far.
// In the SendThread merge loop shown in this file, each poll of an empty queue
// sleeps for 10 ms, and the wait counter is reset to 0 whenever the queue is
// found non-empty, so the limit applies to consecutive empty polls only.
DEFINE_int32(communicator_send_wait_times, 5,
|
|
|
|
|
"times that send thread will wait if merge num does not reach "
|
|
|
|
|
"max_merge_var_num");
|
|
|
|
|
// Command-line flag (int32, default 20): upper bound on how many queued
// variables SendThread pops and merges into a single send; the merge loop
// keeps collecting while merged_var_num is below this value.
// NOTE(review): merged_var_num is declared as size_t while this flag is a
// signed 32-bit value, so the loop condition is a signed/unsigned comparison;
// a negative flag value would presumably behave like a very large limit --
// confirm and consider validating the flag.
DEFINE_int32(communicator_max_merge_var_num, 20,
|
|
|
|
|
"max var num to merge and send");
|
|
|
|
|
DEFINE_bool(communicator_fake_rpc, false,
|
|
|
|
@ -101,20 +104,32 @@ void Communicator::SendThread() {
|
|
|
|
|
VLOG(3) << var_name << " merge and send";
|
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
|
|
size_t merged_var_num = 0;
|
|
|
|
|
while (var_queue->Size() > 0 &&
|
|
|
|
|
merged_var_num < FLAGS_communicator_max_merge_var_num) {
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
// only count the send number of the first var
|
|
|
|
|
if (var_name == send_varname_to_queue_.begin()->first) {
|
|
|
|
|
grad_num_.fetch_add(1, std::memory_order_relaxed);
|
|
|
|
|
size_t wait_times = 0;
|
|
|
|
|
while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
|
|
|
|
|
if (var_queue->Size() == 0) {
|
|
|
|
|
VLOG(3) << "wait_times -> " << wait_times;
|
|
|
|
|
if (wait_times >= FLAGS_communicator_send_wait_times) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(10));
|
|
|
|
|
wait_times++;
|
|
|
|
|
continue;
|
|
|
|
|
} else {
|
|
|
|
|
wait_times = 0;
|
|
|
|
|
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
// only count the send number of the first var
|
|
|
|
|
if (var_name == send_varname_to_queue_.begin()->first) {
|
|
|
|
|
grad_num_.fetch_add(1, std::memory_order_relaxed);
|
|
|
|
|
}
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
auto before_merge = GetCurrentUS();
|
|
|
|
|
MergeVars(var_name, vars, send_scope_.get());
|
|
|
|
|
auto after_merge = GetCurrentUS();
|
|
|
|
|
VLOG(3) << "merge " << var_name << " use time "
|
|
|
|
|
<< after_merge - before_merge;
|
|
|
|
|
VLOG(3) << "merge " << merged_var_num << " " << var_name
|
|
|
|
|
<< " use time " << after_merge - before_merge;
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto &ctx = send_varname_to_ctx_.at(var_name);
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|