|
|
|
@ -30,7 +30,11 @@ DEFINE_bool(communicator_independent_recv_thread, true,
|
|
|
|
|
DEFINE_int32(communicator_send_queue_size, 20,
|
|
|
|
|
"queue size to recv gradient before send");
|
|
|
|
|
DEFINE_int32(communicator_recv_wait_ms, 200, "wait time between each recv");
|
|
|
|
|
DEFINE_int32(communicator_thread_pool_size, 5, "wait time between each recv");
|
|
|
|
|
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
|
|
|
|
|
DEFINE_int32(communicator_max_merge_var_num, 20,
|
|
|
|
|
"max var num to merge and send");
|
|
|
|
|
DEFINE_bool(communicator_fake_rpc, false,
|
|
|
|
|
"fake mode does not really send any thing");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -92,6 +96,9 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
VLOG(0) << "communicator_recv_wait_ms: " << FLAGS_communicator_recv_wait_ms;
|
|
|
|
|
VLOG(0) << "communicator_thread_pool_size: "
|
|
|
|
|
<< FLAGS_communicator_thread_pool_size;
|
|
|
|
|
VLOG(0) << "communicator_max_merge_var_num"
|
|
|
|
|
<< FLAGS_communicator_max_merge_var_num;
|
|
|
|
|
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
|
|
|
|
|
send_scope_.reset(new Scope());
|
|
|
|
|
for (auto &iter : send_varname_to_ctx_) {
|
|
|
|
|
send_varname_to_queue_[iter.first] =
|
|
|
|
@ -123,17 +130,18 @@ void Communicator::SendThread() {
|
|
|
|
|
auto send_task = [this, &var_name, &var_queue] {
|
|
|
|
|
VLOG(3) << "merge var " << var_name << " and send";
|
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
|
|
// TODO(qiao): need to be configurable
|
|
|
|
|
const size_t max_merge_var_num = 20;
|
|
|
|
|
size_t merged_var_num = 0;
|
|
|
|
|
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
|
|
|
|
|
while (var_queue->Size() > 0 &&
|
|
|
|
|
merged_var_num < FLAGS_communicator_max_merge_var_num) {
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
MergeVars(var_name, vars, send_scope_.get());
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto &ctx = send_varname_to_ctx_.at(var_name);
|
|
|
|
|
send_functor(ctx, *send_scope_, true);
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|
send_functor(ctx, *send_scope_, true);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(
|
|
|
|
|
send_threadpool_->enqueue(std::move(send_task)));
|
|
|
|
@ -160,7 +168,9 @@ void Communicator::RecvAll() {
|
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
|
VLOG(3) << "recv var " << var_name;
|
|
|
|
|
auto recv_functor = distributed::ParameterRecv<float>();
|
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|
recv_functor(iter.second, *recv_scope_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
|
|
|
|
|
}
|
|
|
|
|