@@ -25,9 +25,9 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-static void MergeVars(const std::string &var_name,
-                      const std::vector<std::shared_ptr<Variable>> &vars,
-                      Scope *scope) {
+static inline void MergeVars(const std::string &var_name,
+                             const std::vector<std::shared_ptr<Variable>> &vars,
+                             Scope *scope) {
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto &var0 = vars[0];
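Reviewer note: MergeVars folds every queued copy of a variable into a single
value in scope, presumably by element-wise summation of the gradients (its
body is elided by this hunk). A minimal standalone sketch of that merge idea
for dense values; MergeDense is a hypothetical name, not the PR's code, which
works on framework Variables inside a Scope:

    #include <vector>

    // Hypothetical illustration of merge-by-sum for dense gradients.
    static void MergeDense(const std::vector<std::vector<float>> &copies,
                           std::vector<float> *out) {
      *out = copies[0];  // Start from the first queued copy.
      for (size_t i = 1; i < copies.size(); ++i) {
        for (size_t j = 0; j < out->size(); ++j) {
          (*out)[j] += copies[i][j];  // Accumulate the remaining copies.
        }
      }
    }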
@@ -62,31 +62,53 @@ static void MergeVars(const std::string &var_name,
 }
 
 void Communicator::SendThread() {
-  for (auto &iter : send_varname_to_queue_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "merge var " << var_name << " and send";
-    auto &var_queue = iter.second;
-    std::vector<std::shared_ptr<Variable>> vars;
-    const size_t max_merge_var_num = 20;
-    size_t merged_var_num = 0;
-    while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
-      vars.push_back(var_queue->Pop());
-      merged_var_num++;
+  while (running_) {
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(send_varname_to_ctx_.size());
+    for (auto &iter : send_varname_to_queue_) {
+      auto send_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "merge var " << var_name << " and send";
+        auto &var_queue = iter.second;
+        std::vector<std::shared_ptr<Variable>> vars;
+        const size_t max_merge_var_num = 20;
+        size_t merged_var_num = 0;
+        while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
+          vars.push_back(var_queue->Pop());
+          merged_var_num++;
+        }
+        MergeVars(var_name, vars, send_scope_.get());
+        auto send_functor = distributed::ParameterSend<float>();
+        auto &ctx = send_varname_to_ctx_.at(var_name);
+        send_functor(ctx, *send_scope_, true);
+      };
+      task_futures.emplace_back(
+          send_threadpool_->enqueue(std::move(send_task)));
+    }
+    for (auto &task_f : task_futures) {
+      task_f.wait();
     }
-    MergeVars(var_name, vars, send_scope_.get());
-    // auto send_functor = distributed::ParameterSend<float>();
-    // send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
-    //              send_scope_, true);
   }
 }
 
 void Communicator::RecvThread() {
-  // parallel run recv graph
-  for (auto &iter : recv_varname_to_ctx_) {
-    auto &var_name = iter.first;
-    VLOG(3) << "recv var " << iter.first;
-    // auto recv_functor = distributed::ParameterRecv<float>();
-    // recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
+  while (running_) {
+    // parallel run recv graph
+    std::vector<std::future<void>> task_futures;
+    task_futures.reserve(recv_varname_to_ctx_.size());
+    for (auto &iter : recv_varname_to_ctx_) {
+      auto recv_task = [this, &iter] {
+        auto &var_name = iter.first;
+        VLOG(3) << "recv var " << var_name;
+        auto recv_functor = distributed::ParameterRecv<float>();
+        recv_functor(iter.second, *recv_scope_);
+      };
+      task_futures.emplace_back(
+          recv_threadpool_->enqueue(std::move(recv_task)));
+    }
+    for (auto &task : task_futures) {
+      task.wait();
+    }
   }
 }
 
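Reviewer note: both rewritten loops now follow a fan-out/fan-in pattern per
pass: enqueue one task per variable on the thread pool, collect the returned
futures, and block on all of them before starting the next pass, so passes
never overlap. A minimal standalone sketch of the same pattern, with
std::async standing in for the pool's enqueue (which the diff shows returning
std::future<void>); RunOnePass and the work map are hypothetical:

    #include <future>
    #include <map>
    #include <string>
    #include <vector>

    void RunOnePass(const std::map<std::string, int> &work) {
      std::vector<std::future<void>> task_futures;
      task_futures.reserve(work.size());
      for (const auto &iter : work) {
        // Fan out: one asynchronous task per variable. Capturing &iter is
        // safe because map elements outlive the waits below.
        task_futures.emplace_back(std::async(std::launch::async, [&iter] {
          // ... send or recv the variable named iter.first ...
        }));
      }
      for (auto &f : task_futures) {
        f.wait();  // Fan in: the pass ends only when every task is done.
      }
    }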
@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
 }
 
 void Communicator::Start() {
+  running_ = true;
   // start send and recv thread
   send_thread_.reset(
       new std::thread(std::bind(&Communicator::SendThread, this)));
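Reviewer note: Start() now raises running_ before spawning the loop threads,
and the new while (running_) checks give both loops a clean exit point. Since
the flag is written by the controlling thread and read by the workers, it
should be atomic or otherwise synchronized. A minimal sketch of that
start/stop lifecycle, assuming running_ is std::atomic<bool>; the matching
Stop() is outside this hunk, and LoopRunner is a hypothetical stand-in:

    #include <atomic>
    #include <memory>
    #include <thread>

    class LoopRunner {
     public:
      void Start() {
        running_ = true;  // Must be set before the worker first reads it.
        worker_.reset(new std::thread([this] {
          while (running_) {
            // ... run one send/recv pass ...
          }
        }));
      }
      void Stop() {
        running_ = false;  // The loop exits after its current pass.
        if (worker_ && worker_->joinable()) worker_->join();
      }

     private:
      std::atomic<bool> running_{false};
      std::unique_ptr<std::thread> worker_;
    };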