|
|
|
@ -40,6 +40,8 @@ DEFINE_int32(communicator_max_merge_var_num, 20,
|
|
|
|
|
"max var num to merge and send");
|
|
|
|
|
// Boolean command-line flag `communicator_fake_rpc` (default: false).
// In the visible Communicator::Send code, the direct ParameterSend call for
// sparse gradients is only issued when this flag is false; when it is true
// that call is skipped. The constructor logs the flag's value at VLOG(0).
// NOTE(review): other uses of this flag may exist outside this view.
DEFINE_bool(communicator_fake_rpc, false,
|
|
|
|
|
"fake mode does not really send any thing");
|
|
|
|
|
// Boolean command-line flag `communicator_merge_sparse_grad` (default: true).
// Communicator::Send consults it for variables holding a
// framework::SelectedRows: when the flag is false, such a gradient is sent
// right away through distributed::ParameterSend<float>; otherwise (flag true,
// or any other variable type) the gradient is copied and pushed onto the
// per-variable send queue. The constructor logs the flag's value at VLOG(0).
// NOTE(review): how queued gradients are merged later is not visible here.
DEFINE_bool(communicator_merge_sparse_grad, true,
|
|
|
|
|
"merge sparse gradient before sending");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -73,6 +75,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
|
|
|
|
|
VLOG(0) << "communicator_max_merge_var_num: "
|
|
|
|
|
<< FLAGS_communicator_max_merge_var_num;
|
|
|
|
|
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
|
|
|
|
|
VLOG(0) << "communicator_merge_sparse_grad: "
|
|
|
|
|
<< FLAGS_communicator_merge_sparse_grad;
|
|
|
|
|
send_scope_.reset(new Scope());
|
|
|
|
|
for (auto &iter : send_varname_to_ctx_) {
|
|
|
|
|
send_varname_to_queue_[iter.first] =
|
|
|
|
@ -214,11 +218,20 @@ void Communicator::Send(const std::string &var_name,
|
|
|
|
|
// push var into send queue by var_name
|
|
|
|
|
auto *grad_var = scope.FindVar(var_name);
|
|
|
|
|
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
|
|
|
|
|
auto tmp_grad_var = std::make_shared<Variable>();
|
|
|
|
|
framework::CopyVariable(*grad_var, tmp_grad_var.get());
|
|
|
|
|
auto &queue = send_varname_to_queue_.at(var_name);
|
|
|
|
|
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
|
|
|
|
|
queue->Push(tmp_grad_var);
|
|
|
|
|
if (grad_var->IsType<framework::SelectedRows>() &&
|
|
|
|
|
!FLAGS_communicator_merge_sparse_grad) {
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto &ctx = send_varname_to_ctx_.at(var_name);
|
|
|
|
|
if (!FLAGS_communicator_fake_rpc) {
|
|
|
|
|
send_functor(ctx, scope, true);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto tmp_grad_var = std::make_shared<Variable>();
|
|
|
|
|
framework::CopyVariable(*grad_var, tmp_grad_var.get());
|
|
|
|
|
auto &queue = send_varname_to_queue_.at(var_name);
|
|
|
|
|
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
|
|
|
|
|
queue->Push(tmp_grad_var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::Init(const paddle::framework::ProgramDesc &program,
|
|
|
|
|