|
|
|
@ -14,6 +14,9 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/distributed/communicator.h"
|
|
|
|
|
|
|
|
|
|
#include <chrono> // NOLINT
|
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/selected_rows.h"
|
|
|
|
|
#include "paddle/fluid/framework/tensor_util.h"
|
|
|
|
|
#include "paddle/fluid/framework/variable_helper.h"
|
|
|
|
@ -28,6 +31,7 @@ namespace distributed {
|
|
|
|
|
static inline void MergeVars(const std::string &var_name,
|
|
|
|
|
const std::vector<std::shared_ptr<Variable>> &vars,
|
|
|
|
|
Scope *scope) {
|
|
|
|
|
VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to one";
|
|
|
|
|
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
|
|
|
|
|
auto cpu_place = platform::CPUPlace();
|
|
|
|
|
auto &var0 = vars[0];
|
|
|
|
@ -67,29 +71,32 @@ std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
|
|
|
|
|
std::once_flag Communicator::init_flag_;
|
|
|
|
|
|
|
|
|
|
void Communicator::SendThread() {
|
|
|
|
|
// Log thread start-up at verbosity level 3, matching RecvThread; VLOG takes
// the level as its argument and the message is streamed with operator<<.
VLOG(3) << "SendThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
|
task_futures.reserve(send_varname_to_ctx_.size());
|
|
|
|
|
for (auto &iter : send_varname_to_queue_) {
|
|
|
|
|
auto send_task = [this, &iter] {
|
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
|
VLOG(3) << "merge var " << var_name << " and send";
|
|
|
|
|
auto &var_queue = iter.second;
|
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
|
|
// TODO(qiao): need to be configurable
|
|
|
|
|
const size_t max_merge_var_num = 20;
|
|
|
|
|
size_t merged_var_num = 0;
|
|
|
|
|
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
MergeVars(var_name, vars, send_scope_.get());
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto &ctx = send_varname_to_ctx_.at(var_name);
|
|
|
|
|
send_functor(ctx, *send_scope_, true);
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(
|
|
|
|
|
send_threadpool_->enqueue(std::move(send_task)));
|
|
|
|
|
auto &var_name = iter.first;
|
|
|
|
|
auto &var_queue = iter.second;
|
|
|
|
|
if (var_queue->NotEmpty()) { // will block if queue is empty
|
|
|
|
|
auto send_task = [this, &var_name, &var_queue] {
|
|
|
|
|
VLOG(3) << "merge var " << var_name << " and send";
|
|
|
|
|
std::vector<std::shared_ptr<Variable>> vars;
|
|
|
|
|
// TODO(qiao): need to be configurable
|
|
|
|
|
const size_t max_merge_var_num = 20;
|
|
|
|
|
size_t merged_var_num = 0;
|
|
|
|
|
while (var_queue->Size() > 0 && merged_var_num < max_merge_var_num) {
|
|
|
|
|
vars.push_back(var_queue->Pop());
|
|
|
|
|
merged_var_num++;
|
|
|
|
|
}
|
|
|
|
|
MergeVars(var_name, vars, send_scope_.get());
|
|
|
|
|
auto send_functor = distributed::ParameterSend<float>();
|
|
|
|
|
auto &ctx = send_varname_to_ctx_.at(var_name);
|
|
|
|
|
send_functor(ctx, *send_scope_, true);
|
|
|
|
|
};
|
|
|
|
|
task_futures.emplace_back(
|
|
|
|
|
send_threadpool_->enqueue(std::move(send_task)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &task_f : task_futures) {
|
|
|
|
|
task_f.wait();
|
|
|
|
@ -98,6 +105,7 @@ void Communicator::SendThread() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Communicator::RecvThread() {
|
|
|
|
|
VLOG(3) << "RecvThread start!";
|
|
|
|
|
while (running_) {
|
|
|
|
|
// parallel run recv graph
|
|
|
|
|
std::vector<std::future<void>> task_futures;
|
|
|
|
@ -115,6 +123,8 @@ void Communicator::RecvThread() {
|
|
|
|
|
for (auto &task : task_futures) {
|
|
|
|
|
task.wait();
|
|
|
|
|
}
|
|
|
|
|
// TODO(qiao): need to be configurable
|
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|