Merge pull request #16172 from jacquesqiao/add-async-ssa-graph-executor-communicator
Add async ssa graph executor communicator
commit 21622ca30b
@@ -0,0 +1,203 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"

#include "paddle/fluid/framework/variable_helper.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif

namespace paddle {
namespace framework {
namespace details {

inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
                                    Scope *scope) {
  VLOG(3) << "NewTempScopeAndInitVars";
  Scope &local_scope = scope->NewScope();
  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
      &local_scope;

  for (auto &info : var_infos) {
    if (scope->FindVar(info.name_) != nullptr) {
      continue;
    }

    if (info.persistable_) {  // Persistable
      InitializeVariable(scope->Var(info.name_), info.type_);
    } else {
      InitializeVariable(local_scope.Var(info.name_), info.type_);
    }
  }
}

// get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
  using RpcCtxMap = operators::distributed::RpcCtxMap;
  VLOG(3) << "ProcessGraph";
  RpcCtxMap send_varname_to_ctx;
  RpcCtxMap recv_varname_to_ctx;
  for (auto i = 0; i < graphs.size(); ++i) {
    std::vector<ir::Node *> nodes_to_delete;
    for (auto &node : graphs[i]->Nodes()) {
      VLOG(3) << "node name " << node->Name();
      if (node && node->IsOp()) {
        if (node->Name() == "send") {
          auto send_var_name = node->Op()->Input("X")[0];
          auto send_varnames = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("send_varnames"));
          auto epmap = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("epmap"));
          auto height_section = boost::get<std::vector<int64_t>>(
              node->Op()->GetNullableAttr("sections"));
          send_varname_to_ctx[send_var_name] =
              operators::distributed::RpcContext(send_var_name, send_varnames,
                                                 epmap, height_section);
          VLOG(3) << "find and init a send op: "
                  << send_varname_to_ctx[send_var_name];
        } else if (node->Name() == "recv") {
          auto recv_var_name = node->Op()->Output("Out")[0];
          auto recv_varnames = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("recv_varnames"));
          auto epmap = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("epmap"));
          recv_varname_to_ctx[recv_var_name] =
              operators::distributed::RpcContext(recv_var_name, recv_varnames,
                                                 epmap, {});
          nodes_to_delete.push_back(node);
          VLOG(3) << "find and remove a recv op: "
                  << recv_varname_to_ctx[recv_var_name];
        }
      }
    }
  }
  // init communicator here
  if (send_varname_to_ctx.size() > 0) {
    VLOG(3) << "this is distribute mode, will use communicator";
    operators::distributed::Communicator::Init(send_varname_to_ctx,
                                               recv_varname_to_ctx, scope);
    operators::distributed::Communicator::GetInstance()->Start();
  }
#endif
}

AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
    : strategy_(std::move(strategy)),
      local_scopes_(std::move(local_scopes)),
      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
      places_(std::move(places)),
      graphs_(std::move(graphs)) {
  VLOG(3) << "build AsyncSSAGraphExecutor";
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());

  // set the correct size of thread pool to each device.
  strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                               ? 1UL
                               : strategy_.num_threads_ / places_.size();
  VLOG(1) << "set num_threads: " << strategy_.num_threads_
          << " to run the operators of the graph on each device.";
  for (size_t i = 0; i < places.size(); ++i) {
    executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
  }

  for (auto &node : graphs_[0]->Nodes()) {
    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
      var_infos_.emplace_back();
      var_infos_.back().name_ = node->Var()->Name();
      var_infos_.back().type_ = node->Var()->GetType();
      var_infos_.back().persistable_ = node->Var()->Persistable();
    }
  }
  for (auto *scope : local_scopes_) {
    NewTempScopeAndInitVars(var_infos_, scope);
  }
  ProcessGraph(graphs_, local_scopes_[0]);
}

void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
  VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
  for (size_t i = 1; i < places_.size(); ++i) {
    auto call = [this, i]() -> void {
      VLOG(3) << "start off python thread " << i;
      try {
        while (true) {
          executors_[i]->Run({});
        }
      } catch (...) {
        exception_holder_.Catch(std::current_exception());
        VLOG(3) << "get exception type = " << exception_holder_.Type();
      }
      VLOG(3) << "thread " << i << " exited!";
    };
    run_futures_.emplace_back(pool_->enqueue(std::move(call)));
  }
}

void AsyncSSAGraphExecutor::HandleException() {
  if (exception_holder_.IsCaught()) {
    for (auto &f : run_futures_) {
      VLOG(3) << "wait future";
      f.wait();
    }
    VLOG(3) << "caught exception " << exception_holder_.Type()
            << ", rethrow it";
    run_futures_.clear();
    exception_holder_.ReThrow();
  }
}

FeedFetchList AsyncSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // init once
  if (run_futures_.size() == 0 && places_.size() > 1) {
    exception_holder_.Clear();
    StartOffPythonTrainLoop();
  }

  if (places_.size() == 1) {
    exception_holder_.Clear();
  } else {
    HandleException();
  }

  FeedFetchList fetch_data;
  fetch_data.reserve(fetch_tensors.size());

  try {
    fetch_data = executors_[0]->Run(fetch_tensors);
  } catch (...) {
    exception_holder_.Catch(std::current_exception());
  }

  HandleException();

  FeedFetchList ret;
  for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
    std::vector<const LoDTensor *> lodtensor_ptrs;
    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
    ret.emplace_back();
    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
  }
  return ret;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
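A note on the error-handling pattern above: StartOffPythonTrainLoop keeps every place other than the first running in an endless background loop, and an exception raised there is only parked in exception_holder_ and rethrown on the calling thread by HandleException during the next Run. The standalone sketch below (standard C++ only, not part of this PR; ExceptionHolder in the framework adds type tracking and, presumably, locking on top of this) shows the same catch-now/rethrow-later flow:

#include <exception>
#include <future>
#include <iostream>
#include <stdexcept>

int main() {
  std::exception_ptr caught = nullptr;

  // Background worker: catch the error where it happens and park it,
  // the way the per-place loops park errors in exception_holder_.
  auto worker = std::async(std::launch::async, [&caught] {
    try {
      throw std::runtime_error("worker failed");  // stands in for Run() throwing
    } catch (...) {
      caught = std::current_exception();
    }
  });

  // Caller: wait for the worker, then rethrow on this thread,
  // as HandleException does before the next Run.
  worker.wait();
  if (caught) {
    try {
      std::rethrow_exception(caught);
    } catch (const std::exception &e) {
      std::cout << "rethrown on caller thread: " << e.what() << "\n";
    }
  }
  return 0;
}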
@@ -0,0 +1,65 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

namespace paddle {
namespace framework {
namespace details {

struct VarInfo {
  std::string name_;
  proto::VarType::Type type_;
  bool persistable_;
};

class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 public:
  AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                        const std::vector<Scope *> &local_scopes,
                        const std::vector<platform::Place> &places,
                        std::vector<ir::Graph *> graphs);
  ~AsyncSSAGraphExecutor() final = default;
  const ir::Graph &Graph() const override { return *graphs_[0]; }

  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

 private:
  void StartOffPythonTrainLoop();
  void HandleException();

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};
  std::vector<platform::Place> places_;
  std::vector<ir::Graph *> graphs_;

  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
  std::vector<std::future<void>> run_futures_;
  std::vector<VarInfo> var_infos_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,213 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"

#include <gflags/gflags.h>
#include <chrono>  // NOLINT
#include <thread>  // NOLINT

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"

DEFINE_bool(communicator_independent_recv_thread, true,
            "use an independent thread to recv vars from parameter server");
DEFINE_int32(communicator_send_queue_size, 20,
             "queue size to recv gradient before send");
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
             "max grad num to send before recv parameters");
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
DEFINE_int32(communicator_max_merge_var_num, 20,
             "max var num to merge and send");
DEFINE_bool(communicator_fake_rpc, false,
            "fake mode does not really send anything");

namespace paddle {
namespace operators {
namespace distributed {

inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_;

Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
                           const RpcCtxMap &recv_varname_to_ctx,
                           Scope *recv_scope)
    : send_varname_to_ctx_(send_varname_to_ctx),
      recv_varname_to_ctx_(recv_varname_to_ctx),
      recv_scope_(recv_scope) {
  // get all send information from graph, build vars_to_send
  VLOG(0) << "communicator_independent_recv_thread: "
          << FLAGS_communicator_independent_recv_thread;
  VLOG(0) << "communicator_send_queue_size: "
          << FLAGS_communicator_send_queue_size;
  VLOG(0) << "communicator_max_send_grad_num_before_recv: "
          << FLAGS_communicator_max_send_grad_num_before_recv;
  VLOG(0) << "communicator_thread_pool_size: "
          << FLAGS_communicator_thread_pool_size;
  VLOG(0) << "communicator_max_merge_var_num: "
          << FLAGS_communicator_max_merge_var_num;
  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
  send_scope_.reset(new Scope());
  for (auto &iter : send_varname_to_ctx_) {
    send_varname_to_queue_[iter.first] =
        std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
            FLAGS_communicator_send_queue_size);
  }
  send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
  recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}

Communicator::~Communicator() {
  VLOG(3) << "~Communicator";
  running_ = false;
  if (send_thread_) send_thread_->join();
  if (recv_thread_) recv_thread_->join();
  VLOG(3) << "~Communicator done";
}

void Communicator::SendThread() {
  VLOG(3) << "SendThread start!";
  while (running_) {
    std::vector<std::future<void>> task_futures;
    task_futures.reserve(send_varname_to_ctx_.size());
    VLOG(3) << "run send graph";
    auto before_run_send_graph = GetCurrentUS();
    for (auto &iter : send_varname_to_queue_) {
      auto &var_name = iter.first;
      auto &var_queue = iter.second;
      if (var_queue->Size() > 0) {
        auto send_task = [this, &var_name, &var_queue] {
          VLOG(3) << var_name << " merge and send";
          std::vector<std::shared_ptr<Variable>> vars;
          size_t merged_var_num = 0;
          while (var_queue->Size() > 0 &&
                 merged_var_num < FLAGS_communicator_max_merge_var_num) {
            vars.push_back(var_queue->Pop());
            // only count the send number of the first var
            if (var_name == send_varname_to_queue_.begin()->first) {
              grad_num_.fetch_add(1, std::memory_order_relaxed);
            }
            merged_var_num++;
          }
          auto before_merge = GetCurrentUS();
          MergeVars(var_name, vars, send_scope_.get());
          auto after_merge = GetCurrentUS();
          VLOG(3) << "merge " << var_name << " use time "
                  << after_merge - before_merge;
          auto send_functor = distributed::ParameterSend<float>();
          auto &ctx = send_varname_to_ctx_.at(var_name);
          if (!FLAGS_communicator_fake_rpc) {
            send_functor(ctx, *send_scope_, true);
          }
          auto after_send = GetCurrentUS();
          VLOG(3) << "send " << var_name << " use time "
                  << after_send - after_merge;
        };
        task_futures.emplace_back(
            send_threadpool_->enqueue(std::move(send_task)));
      } else {
        VLOG(3) << var_name << " queue empty";
      }
    }
    for (auto &task_f : task_futures) {
      task_f.wait();
    }
    auto after_run_send_graph = GetCurrentUS();
    auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
    if (send_graph_use_time > 100) {
      VLOG(1) << "run send graph use time "
              << after_run_send_graph - before_run_send_graph;
    }
    if (!FLAGS_communicator_independent_recv_thread) {
      RecvAll();
    }
  }
}

void Communicator::RecvAll() {
  VLOG(3) << "parallel run recv graph";
  auto before_send = GetCurrentUS();
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());
  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(3) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
      if (!FLAGS_communicator_fake_rpc) {
        recv_functor(iter.second, *recv_scope_);
      }
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : task_futures) {
    task.wait();
  }
  auto after_recv = GetCurrentUS();
  VLOG(1) << "run recv graph use time " << after_recv - before_send;
}

void Communicator::RecvThread() {
  VLOG(3) << "RecvThread start!";
  while (running_) {
    auto grad_num = grad_num_.load();
    if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) {
      VLOG(1) << "current grad num " << grad_num;
      RecvAll();
      grad_num_.store(0);
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}

void Communicator::Send(const std::string &var_name,
                        const framework::Scope &scope) {
  VLOG(3) << "communicator send " << var_name;
  // push var into send queue by var_name
  auto *grad_var = scope.FindVar(var_name);
  PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
  auto tmp_grad_var = std::make_shared<Variable>();
  framework::CopyVariable(*grad_var, tmp_grad_var.get());
  auto &queue = send_varname_to_queue_.at(var_name);
  VLOG(3) << "send " << var_name << " queue size " << queue->Size();
  queue->Push(tmp_grad_var);
}

Communicator *Communicator::GetInstance() { return communicator_.get(); }

void Communicator::Start() {
  running_ = true;
  // start send and recv thread
  send_thread_.reset(
      new std::thread(std::bind(&Communicator::SendThread, this)));
  if (FLAGS_communicator_independent_recv_thread) {
    recv_thread_.reset(
        new std::thread(std::bind(&Communicator::RecvThread, this)));
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
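For orientation, here is a minimal sketch of the lifecycle that ProcessGraph drives automatically: build RpcCtxMap entries (mirroring the RpcContext construction used in async_ssa_graph_executor.cc above), initialize the singleton, start the send/recv threads, and push one gradient. This is illustrative only and not part of the PR; the endpoint, variable names, and tensor shape are made up, and running it against a real parameter server would need the corresponding server side (or --communicator_fake_rpc=true to skip the actual RPC).

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/distributed/communicator.h"

namespace distributed = paddle::operators::distributed;
namespace framework = paddle::framework;

void CommunicatorDemo() {
  // One send context and one recv context, mirroring what ProcessGraph
  // extracts from the send/recv ops of the trainer graph (hypothetical
  // endpoint and variable names).
  distributed::RpcCtxMap send_ctxs;
  distributed::RpcCtxMap recv_ctxs;
  send_ctxs["w@GRAD"] = distributed::RpcContext(
      "w@GRAD", {"w@GRAD"}, {"127.0.0.1:6174"}, {} /*height_sections*/);
  recv_ctxs["w"] = distributed::RpcContext("w", {"w"}, {"127.0.0.1:6174"}, {});

  framework::Scope scope;  // stands in for the global recv scope
  auto *grad = scope.Var("w@GRAD")->GetMutable<framework::LoDTensor>();
  grad->mutable_data<float>(framework::make_ddim({2, 2}),
                            paddle::platform::CPUPlace());

  distributed::Communicator::Init(send_ctxs, recv_ctxs, &scope);
  // Start() spawns SendThread (merge + ParameterSend) and, by default,
  // an independent RecvThread (ParameterRecv).
  distributed::Communicator::GetInstance()->Start();
  // Enqueue one gradient; SendThread will merge and send it asynchronously.
  distributed::Communicator::GetInstance()->Send("w@GRAD", scope);
}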
@@ -0,0 +1,219 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ThreadPool.h>

#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {
namespace distributed {

using Scope = framework::Scope;
using Variable = framework::Variable;

template <typename T>
class BlockingQueue {
 public:
  explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
    PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0.");
  }

  bool Push(const T& elem) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
      PADDLE_ENFORCE_LT(queue_.size(), capacity_);
      queue_.push_back(elem);
    }
    cv_.notify_one();
    return true;
  }

  bool Push(T&& elem) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      cv_.wait(lock, [&] { return queue_.size() < capacity_; });
      PADDLE_ENFORCE_LT(queue_.size(), capacity_);
      queue_.emplace_back(std::move(elem));
    }
    cv_.notify_one();
    return true;
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [=] { return !queue_.empty(); });
    T rc(std::move(queue_.front()));
    queue_.pop_front();
    cv_.notify_one();
    return rc;
  }

  size_t Cap() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return capacity_;
  }

  size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  const size_t capacity_;
  std::deque<T> queue_;

  mutable std::mutex mutex_;
  std::condition_variable cv_;
};

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

inline void MergeVars(const std::string& var_name,
                      const std::vector<std::shared_ptr<Variable>>& vars,
                      Scope* scope) {
  PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
  auto cpu_place = platform::CPUPlace();
  auto& var0 = vars[0];
  auto* out_var = scope->Var(var_name);
  if (var0->IsType<framework::LoDTensor>()) {
    auto dims = var0->Get<framework::LoDTensor>().dims();
    VLOG(3) << "merge " << var_name << " LoDTensor " << dims;

    // init output tensor
    auto* out_t = out_var->GetMutable<framework::LoDTensor>();
    out_t->mutable_data<float>(dims, cpu_place);

    // check the input dims
    for (auto& var : vars) {
      auto& var_t = var->Get<framework::LoDTensor>();
      PADDLE_ENFORCE_EQ(var_t.dims(), dims, "should have the same dims");
    }

    // set output tensor to 0.
    auto cpu_ctx = paddle::platform::CPUDeviceContext();
    math::SetConstant<paddle::platform::CPUDeviceContext, float>
        constant_functor;
    constant_functor(cpu_ctx, out_t, static_cast<float>(0));

    // sum all vars to out
    auto result = EigenVector<float>::Flatten(*out_t);
    for (auto& var : vars) {
      auto& in_t = var->Get<framework::LoDTensor>();
      auto in = EigenVector<float>::Flatten(in_t);
      result.device(*cpu_ctx.eigen_device()) = result + in;
    }
  } else if (var0->IsType<framework::SelectedRows>()) {
    auto& slr0 = var0->Get<framework::SelectedRows>();
    auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
    out_slr->mutable_rows()->clear();
    out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
    std::vector<const paddle::framework::SelectedRows*> inputs;
    inputs.reserve(vars.size());
    for (auto& var : vars) {
      inputs.push_back(&var->Get<framework::SelectedRows>());
    }
    math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
        merge_add;
    auto dev_ctx = paddle::platform::CPUDeviceContext();
    merge_add(dev_ctx, inputs, out_slr, false);
    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
            << " dims: " << slr0.value().dims();
  } else {
    PADDLE_THROW("unsupported var type!");
  }
}

using RpcCtxMap = std::unordered_map<std::string, RpcContext>;

class Communicator {
 public:
  Communicator(const RpcCtxMap& send_varname_to_ctx,
               const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);

  ~Communicator();

  void Start();

  // send grad
  void Send(const std::string& var_name, const framework::Scope& scope);

 private:
  // recv all parameter
  void RecvAll();
  void SendThread();
  void RecvThread();

  bool running_ = false;
  std::unordered_map<std::string,
                     std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
      send_varname_to_queue_;
  RpcCtxMap send_varname_to_ctx_;
  RpcCtxMap recv_varname_to_ctx_;
  std::unique_ptr<std::thread> send_thread_;
  std::unique_ptr<std::thread> recv_thread_;
  Scope* recv_scope_;                  // should be global scope
  std::unique_ptr<Scope> send_scope_;  // an independent scope
  std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
  std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
  std::atomic_uint grad_num_{0};  // the num of gradient sent since last recv

  // the following code is for initializing the communicator
 public:
  static void Init(const RpcCtxMap& send_varname_to_ctx,
                   const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
    InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
  }

  static Communicator* GetInstance();

 private:
  // Init is called by GetInstance.
  static void InitImpl(const RpcCtxMap& send_varname_to_ctx,
                       const RpcCtxMap& recv_varname_to_ctx,
                       Scope* recv_scope) {
    if (communicator_ == nullptr) {
      communicator_.reset(new Communicator(send_varname_to_ctx,
                                           recv_varname_to_ctx, recv_scope));
    }
  }

 private:
  static std::once_flag init_flag_;
  static std::unique_ptr<Communicator> communicator_;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
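A small usage sketch for the bounded BlockingQueue defined in this header (illustrative only, not part of the PR; it assumes the header above is on the include path): Push blocks once `capacity` items are waiting and Pop blocks while the queue is empty, which is what lets Communicator::Send apply back-pressure to the trainer when the send thread falls behind.

#include <thread>

#include "paddle/fluid/operators/distributed/communicator.h"

int main() {
  // Capacity 2: the producer blocks after two un-popped items.
  paddle::operators::distributed::BlockingQueue<int> q(2);

  std::thread producer([&q] {
    for (int i = 1; i <= 5; ++i) {
      q.Push(i);  // blocks while the queue is full
    }
  });

  int sum = 0;
  for (int i = 0; i < 5; ++i) {
    sum += q.Pop();  // blocks until the producer pushes another item
  }
  producer.join();
  return sum == 15 ? 0 : 1;  // 1 + 2 + 3 + 4 + 5
}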
@@ -0,0 +1,110 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>

#include "paddle/fluid/operators/distributed/communicator.h"

namespace paddle {
namespace operators {
namespace distributed {

using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;

TEST(communicator, merge_lod_tensors) {
  auto cpu_place = platform::CPUPlace();
  auto dims = framework::make_ddim({2, 3});
  std::vector<std::shared_ptr<framework::Variable>> in_vars;
  float out_value = 0;
  for (auto i = 0; i < 10; ++i) {
    auto var = std::make_shared<Variable>();
    in_vars.emplace_back(var);
    auto *tensor = var->GetMutable<LoDTensor>();
    auto *data = tensor->mutable_data<float>(dims, cpu_place);
    for (auto j = 0; j < tensor->numel(); ++j) {
      data[j] = static_cast<float>(i);
    }
    out_value += static_cast<float>(i);
  }
  const std::string out_name = "Out";
  std::unique_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  scope->Var(out_name);
  for (auto i = 0; i < 10; ++i) {
    MergeVars(out_name, in_vars, scope.get());
  }
  auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
  auto *out_data = out_tensor.data<float>();
  ASSERT_EQ(out_tensor.dims(), dims);
  for (auto i = 0; i < out_tensor.numel(); ++i) {
    ASSERT_EQ(out_data[i], out_value);
  }
}

TEST(communicator, merge_selected_rows) {
  auto cpu_place = platform::CPUPlace();
  int64_t width = 10;
  std::vector<std::shared_ptr<framework::Variable>> in_vars;
  const int64_t height = 100;
  for (auto i = 0; i < 10; ++i) {
    std::vector<int64_t> rows;
    for (auto k = 0; k <= i; ++k) {
      rows.push_back(k);
    }
    auto var = std::make_shared<Variable>();
    in_vars.emplace_back(var);
    auto *slr = var->GetMutable<SelectedRows>();
    slr->set_height(height);
    slr->set_rows(rows);
    auto dims =
        framework::make_ddim({static_cast<int64_t>(rows.size()), width});
    auto *data = slr->mutable_value()->mutable_data<float>(dims, cpu_place);
    for (auto i = 0; i < rows.size(); ++i) {
      for (auto j = 0; j < width; ++j) {
        data[i * width + j] = static_cast<float>(rows[i]);
      }
    }
  }
  const std::string out_name = "Out";
  std::unique_ptr<framework::Scope> scope;
  scope.reset(new framework::Scope());
  scope->Var(out_name);
  for (auto i = 0; i < 10; ++i) {
    MergeVars(out_name, in_vars, scope.get());
  }
  auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
  auto &out_t = out_slr.value();
  auto *out_data = out_t.data<float>();
  ASSERT_EQ(out_t.dims(), framework::make_ddim({10, width}));
  std::vector<float> out_values;
  out_values.reserve(10);
  for (auto i = 0; i < 10; ++i) {
    out_values.push_back(static_cast<float>(i * (10 - i)));
  }
  for (auto i = 0; i < out_slr.rows().size(); ++i) {
    ASSERT_EQ(out_slr.rows()[i], i);
    for (auto j = 0; j < width; ++j) {
      ASSERT_EQ(out_data[i * width + j], out_values[i]);
    }
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
Some files were not shown because too many files have changed in this diff.