Merge pull request #16172 from jacquesqiao/add-async-ssa-graph-executor-communicator

Add async ssa graph executor communicator
revert-16555-model_data_cryption_link_all_lib
乔龙飞 Qiao Longfei 6 years ago committed by GitHub
commit 21622ca30b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -196,7 +196,7 @@ endif()
target_link_libraries(executor while_op_helper executor_gc_helper) target_link_libraries(executor while_op_helper executor_gc_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy graph build_strategy
fast_threaded_ssa_graph_executor variable_helper) fast_threaded_ssa_graph_executor variable_helper)

@ -96,6 +96,12 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor) cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
set(ASYNC_SSA_GRAPH_EXECUTOR_DEPS threaded_ssa_graph_executor)
if(WITH_DISTRIBUTE)
list(APPEND ASYNC_SSA_GRAPH_EXECUTOR_DEPS communicator)
endif()
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS ${ASYNC_SSA_GRAPH_EXECUTOR_DEPS})
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle) device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory

@ -0,0 +1,203 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif
namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
// get RpcContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
}
#endif
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
: strategy_.num_threads_ / places_.size();
VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos_.emplace_back();
var_infos_.back().name_ = node->Var()->Name();
var_infos_.back().type_ = node->Var()->GetType();
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
for (size_t i = 1; i < places_.size(); ++i) {
auto call = [this, i]() -> void {
VLOG(3) << "start off python thread " << i;
try {
while (true) {
executors_[i]->Run({});
}
} catch (...) {
exception_holder_.Catch(std::current_exception());
VLOG(3) << "get exception type = " << exception_holder_.Type();
}
VLOG(3) << "thread " << i << " exited!";
};
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
}
}
void AsyncSSAGraphExecutor::HandleException() {
if (exception_holder_.IsCaught()) {
for (auto &f : run_futures_) {
VLOG(3) << "wait future";
f.wait();
}
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
run_futures_.clear();
exception_holder_.ReThrow();
}
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
fetch_data.reserve(fetch_tensors.size());
try {
fetch_data = executors_[0]->Run(fetch_tensors);
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
HandleException();
FeedFetchList ret;
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -0,0 +1,65 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
namespace paddle {
namespace framework {
namespace details {
struct VarInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
void StartOffPythonTrainLoop();
void HandleException();
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_;
std::vector<std::future<void>> run_futures_;
std::vector<VarInfo> var_infos_;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -184,8 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices. // Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) { void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr; ir::Pass *multi_devices_pass = nullptr;
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass"; if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(10)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else { } else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
@ -234,10 +238,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
#else #else
const bool use_cuda) const { const bool use_cuda) const {
#endif #endif
VLOG(3) << "apply all passes";
// Create a default one if not finalized by user. // Create a default one if not finalized by user.
CreatePassesFromStrategy(false); CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) { for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "apply " << pass->Type();
if (IsMultiDevPass(pass->Type())) { if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces); pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places); pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
@ -293,6 +299,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
graph = pass->Apply(graph); graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type(); VLOG(3) << "Finish Apply Pass " << pass->Type();
} }
VLOG(3) << "All Passes Applied";
return graph; return graph;
} }

@ -97,6 +97,7 @@ struct BuildStrategy {
// num_trainers is 1, so the current fields of build_strategy doesn't tell if // num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model. // it's distributed model.
bool is_distribution_{false}; bool is_distribution_{false};
bool async_mode_{false};
int num_trainers_{1}; int num_trainers_{1};
int trainer_id_{0}; int trainer_id_{0};
std::vector<std::string> trainers_endpoints_; std::vector<std::string> trainers_endpoints_;

@ -14,6 +14,9 @@
#pragma once #pragma once
#include <memory>
#include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
@ -64,6 +67,21 @@ class ExceptionHolder {
ClearImpl(); ClearImpl();
} }
std::string Type() {
std::lock_guard<std::mutex> lock(mu_);
switch (type_) {
case kNone:
return "None";
case kEnforceNotMet: {
return "EnforceNotMet";
}
case kEOF: {
return "EOF";
}
}
return "unknown";
}
private: private:
void ClearImpl() { void ClearImpl() {
exception_.reset(); exception_.reset();

@ -31,6 +31,8 @@ struct ExecutionStrategy {
size_t num_iteration_per_drop_scope_{1}; size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault}; ExecutorType type_{kDefault};
bool dry_run_{false}; bool dry_run_{false};
size_t num_iteration_per_run_{1}; // only use with async_ssa_graph_executor
// and pyreader with data queue
}; };
} // namespace details } // namespace details

@ -198,8 +198,22 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
static_cast<bool>(boost::get<int>(node->Op()->GetAttr( static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) & OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward)); static_cast<int>(OpRole::kBackward));
// optimize op is already processed in DealWithSpecialOp,
// here we only consider backward op
if (!is_bk_op) continue; if (!is_bk_op) continue;
/*
* the op that will generate the gradient of on parameter will have
one attr op_role_var
* to record the parameter and gradient, like:
attrs {
name: "op_role_var"
type: STRINGS
strings: "fc_1.b_0"
strings: "fc_1.b_0@GRAD"
}
*/
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
auto backward_vars = auto backward_vars =
@ -256,6 +270,8 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
break; break;
} }
VLOG(3) << "loss_scale: " << loss_scale;
if (loss_scale) { if (loss_scale) {
// TODO(paddle-dev): Why is there no input for this op_handle? // TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
@ -407,7 +423,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result, void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node, ir::Node *node,
int dev_id) const { size_t dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back( result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()), new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id)); local_scopes_[dev_id], places_[dev_id], dev_id));
@ -494,9 +510,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
} }
} }
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result, VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
const std::string &og, ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
int dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle( result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@ -774,6 +789,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else if (OpHaveRole(*node, OpRole::kDist)) { } else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node); int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
// the input(block of parameter) of concat is on different device,
// the output(parameter) will on one device.
auto origin_param_name = node->Op()->OutputArgumentNames()[0]; auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name); bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
} }
@ -781,6 +798,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else { } else {
int op_dev_id = GetOpDeviceID(node); int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
// optimize op will be processed here.
CreateComputationalOp(result, node, op_dev_id); CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id); sharded_var_device_.emplace(n->Name(), op_dev_id);
@ -961,6 +979,7 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result, void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name, const std::string &p_name,
const std::string &g_name) const { const std::string &g_name) const {
// collective gradient to each device
size_t cur_device_id = 0; size_t cur_device_id = 0;
switch (strategy_.reduce_) { switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
@ -1049,3 +1068,5 @@ REGISTER_MULTI_DEVICES_PASS(
paddle::framework::details::AllReduceSSAGraphBuilder); paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass, REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder); paddle::framework::details::DistSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
paddle::framework::details::AsyncSSAGraphBuilder);

@ -56,7 +56,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool UseGPU() const; bool UseGPU() const;
bool NeedCollectiveForGrad(const std::string &grad_name, virtual bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const; std::vector<ir::Node *> ops) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
@ -70,10 +70,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
proto::VarType::Type dtype) const; proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; size_t dev_id) const;
bool IsSparseGradient(const std::string &og) const; bool IsSparseGradient(const std::string &og) const;
@ -115,6 +115,35 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual void InsertPostprocessOps(ir::Graph *result) const {} virtual void InsertPostprocessOps(ir::Graph *result) const {}
}; };
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const override {}
bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const {
return false;
}
bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
if (node->Op()->Type() == "recv") {
VLOG(1) << "set recv op do_not_run to true";
node->Op()->SetAttr("do_not_run", true);
node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// in async_mode, we do not need remote prefetch, because communicator
// will do async parameter recv.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
}
return false;
}
void InsertPostprocessOps(ir::Graph *result) const override {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected: protected:
int GetVarDeviceID(const std::string &varname) const; int GetVarDeviceID(const std::string &varname) const;

@ -31,11 +31,23 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
prepare_pool_(1), prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) { : nullptr) {
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
for (auto *node : graph_->Nodes()) {
if (node->IsOp() && node->Name() == "read") {
read_op_num++;
}
}
if (read_op_num == 0) {
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!";
}
}
PrepareOpDeps(); PrepareOpDeps();
CopyOpDeps(); CopyOpDeps();
} }
FeedFetchList ThreadedSSAGraphExecutor::Run( inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::unique_ptr<platform::RecordEvent> event( std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare")); new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
@ -84,6 +96,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars->PopAll(1, &timeout); auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) { if (timeout) {
if (exception_holder_.IsCaught()) { if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) { for (auto &run_op_future : run_op_futures_) {
run_op_future.wait(); run_op_future.wait();
} }
@ -114,6 +128,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data; return fetch_data;
} }
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
RunImpl({});
}
return RunImpl(fetch_tensors);
}
void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops, std::vector<FetchOpHandle *> *fetch_ops,

@ -23,7 +23,9 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "ThreadPool.h" // ThreadPool in thrird party
#include <ThreadPool.h> // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h" #include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
@ -59,6 +61,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() final = default; ~ThreadedSSAGraphExecutor() final = default;
private: private:
inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q, void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op); details::OpHandleBase *op);

@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
Graph* Pass::Apply(Graph* graph) const { Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty."); PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) { for (const std::string& attr : required_pass_attrs_) {

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h" #include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@ -218,6 +219,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
} }
} }
std::vector<ir::Graph *> graphs;
if (build_strategy.async_mode_) {
PADDLE_ENFORCE(!member_->use_cuda_,
"gpu mode does not support async_mode_ now!");
graphs.push_back(graph);
for (int i = 1; i < places.size(); ++i) {
auto *tmp_graph = new ir::Graph(graph->OriginProgram());
async_graphs_.emplace_back(tmp_graph);
graphs.push_back(tmp_graph);
}
}
// FIXME(Yancey1989): parallel graph mode get better performance // FIXME(Yancey1989): parallel graph mode get better performance
// in GPU allreduce distributed training. Need an elegant way to // in GPU allreduce distributed training. Need an elegant way to
// choice the execution strategy. // choice the execution strategy.
@ -294,19 +307,46 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if (need_broadcast()) { if (need_broadcast()) {
BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_); BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
for (int i = 1; i < member_->places_.size(); ++i) {
graphs[i] =
build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
async_graphs[i] = graphs[i];
}
} else {
graph = build_strategy.Apply(graph, member_->places_, loss_var_name, graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_, member_->local_scopes_, member_->nranks_,
member_->use_cuda_, member_->nccl_ctxs_.get()); member_->use_cuda_, member_->nccl_ctxs_.get());
}
#else #else
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_);
for (int i = 1; i < member_->places_.size(); ++i) {
graphs[i] = build_strategy.Apply(
graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, member_->use_cuda_);
async_graphs[i] = graphs[i];
}
} else {
graph = build_strategy.Apply(graph, member_->places_, loss_var_name, graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_, member_->local_scopes_, member_->nranks_,
member_->use_cuda_); member_->use_cuda_);
}
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
@ -317,6 +357,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
static_cast<size_t>(max_memory_size)); static_cast<size_t>(max_memory_size));
} }
async_graphs[0] = graph;
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars // skip control vars and empty vars
std::vector<details::VariableInfo> var_infos; std::vector<details::VariableInfo> var_infos;
@ -344,7 +386,12 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
} }
} }
if (build_strategy.enable_parallel_graph_) { if (build_strategy.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
} else if (build_strategy.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
// TODO(Yancey1989): Remove passing in the main_program when // TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr. // allreduce_seq_pass doesn't need it as the attr.
@ -356,21 +403,27 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
#endif #endif
} else { } else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->places_, graph));
} else { } else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->places_, graph));
} }
} }
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
if (!build_strategy.async_mode_) {
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_))); member_->places_, std::move(member_->executor_)));
} }
}
void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::BCastParamsToDevices(
const std::vector<std::string> &vars, int trainer_id) const { const std::vector<std::string> &vars, int trainer_id) const {
VLOG(3) << "BCastParamsToDevices";
// the initializing bcast, all vars would be bcast from device(0). // the initializing bcast, all vars would be bcast from device(0).
for (auto &var : vars) { for (auto &var : vars) {
framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var); framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
@ -425,14 +478,22 @@ void ParallelExecutor::BCastParamsToDevices(
auto local_scope = member_->local_scopes_[i]; auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>(); auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix. auto copy_memory = [&] {
if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
t->Resize(dims); t->Resize(dims);
t->mutable_data(cpu, main_tensor.type()); t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t); paddle::framework::TensorCopy(main_tensor, cpu, t);
};
auto share_memory = [&] { t->ShareDataWith(main_tensor); };
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if (member_->build_strategy_.async_mode_) {
share_memory();
} else if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
copy_memory();
} else { } else {
t->ShareDataWith(main_tensor); share_memory();
} }
} }
} }

@ -81,6 +81,7 @@ class ParallelExecutor {
const BuildStrategy &build_strategy) const; const BuildStrategy &build_strategy) const;
ParallelExecutorPrivate *member_; ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::unique_ptr<ncclUniqueId> local_nccl_id_; std::unique_ptr<ncclUniqueId> local_nccl_id_;
#endif #endif

@ -69,6 +69,9 @@ void ReaderBase::Start() {
ReaderBase::~ReaderBase() {} ReaderBase::~ReaderBase() {}
DecoratedReader::~DecoratedReader() { reader_->Shutdown(); } DecoratedReader::~DecoratedReader() {
VLOG(1) << "~DecoratedReader";
reader_->Shutdown();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
@ -77,7 +78,10 @@ class DecoratedReader : public ReaderBase,
~DecoratedReader(); ~DecoratedReader();
protected: protected:
void ShutdownImpl() override { reader_->Shutdown(); } void ShutdownImpl() override {
VLOG(1) << "ShutdownImpl";
reader_->Shutdown();
}
void StartImpl() override { reader_->Start(); } void StartImpl() override { reader_->Start(); }
@ -98,6 +102,8 @@ class ReaderHolder {
reader_ = reader_base; reader_ = reader_base;
} }
~ReaderHolder() { VLOG(1) << "~ReaderHolder"; }
const std::shared_ptr<ReaderBase>& Get() const { return reader_; } const std::shared_ptr<ReaderBase>& Get() const { return reader_; }
void ReadNext(std::vector<LoDTensor>* out) { void ReadNext(std::vector<LoDTensor>* out) {
@ -106,6 +112,7 @@ class ReaderHolder {
} }
void ResetAll() { void ResetAll() {
VLOG(1) << "ResetAll";
auto end_readers = reader_->GetEndPoints(); auto end_readers = reader_->GetEndPoints();
for (auto* reader : end_readers) { for (auto* reader : end_readers) {
reader->Shutdown(); reader->Shutdown();
@ -116,11 +123,13 @@ class ReaderHolder {
} }
void Shutdown() { void Shutdown() {
VLOG(1) << "Shutdown";
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Shutdown(); reader_->Shutdown();
} }
void Start() { void Start() {
VLOG(1) << "start";
PADDLE_ENFORCE_NOT_NULL(reader_); PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->Start(); reader_->Start();
} }

@ -59,6 +59,10 @@ Scope& Scope::NewScope() const {
return *child; return *child;
} }
std::unique_ptr<Scope> Scope::NewTmpScope() const {
return std::unique_ptr<Scope>(new Scope(this));
}
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
SCOPE_VARS_WRITER_LOCK SCOPE_VARS_WRITER_LOCK
return VarInternal(name); return VarInternal(name);

@ -52,6 +52,10 @@ class Scope {
/// Mark it to const because that new kid scope cannot change parent scope. /// Mark it to const because that new kid scope cannot change parent scope.
Scope& NewScope() const; Scope& NewScope() const;
/// Create a sub-scope for current scope but do not record it in the kids to
/// avoid performance problems.
std::unique_ptr<Scope> NewTmpScope() const;
/// Create a variable with given name if it doesn't exist. /// Create a variable with given name if it doesn't exist.
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
Variable* Var(const std::string& name); Variable* Var(const std::string& name);

@ -57,5 +57,27 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
var_type); var_type);
} }
} }
void CopyVariable(const Variable &src_var, Variable *dst_var) {
// only support cpu now
auto cpu_place = platform::CPUPlace();
if (src_var.IsType<framework::LoDTensor>()) {
auto *tmp_grad_tensor = dst_var->GetMutable<framework::LoDTensor>();
auto &src_tensor = src_var.Get<framework::LoDTensor>();
tmp_grad_tensor->set_lod(src_tensor.lod());
framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor);
} else if (src_var.IsType<framework::SelectedRows>()) {
auto &src_slr = src_var.Get<framework::SelectedRows>();
auto *tmp_grad_slr = dst_var->GetMutable<framework::SelectedRows>();
tmp_grad_slr->set_rows(src_slr.rows());
tmp_grad_slr->set_height(src_slr.height());
auto &src_t = src_slr.value();
auto *dst_t = tmp_grad_slr->mutable_value();
framework::TensorCopy(src_t, cpu_place, dst_t);
} else {
PADDLE_THROW("unknown var type to copy");
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -17,7 +17,9 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void InitializeVariable(Variable* var, proto::VarType::Type var_type); void InitializeVariable(Variable* var, proto::VarType::Type var_type);
void CopyVariable(const Variable& src_var, Variable* dst_var);
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle

@ -30,7 +30,7 @@ if(WITH_GRPC)
else() else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc) set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc parameter_send.cc parameter_recv.cc communicator.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib) set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
@ -50,8 +50,12 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL) DEPS ${RPC_DEPS} executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory)
cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv)
cc_test(communicator_test SRCS communicator_test.cc DEPS communicator)
if(WITH_GPU) if(WITH_GPU)
cc_test(collective_server_test SRCS collective_server_test.cc cc_test(collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor ${RPC_DEPS} DEPS sendrecvop_rpc executor ${RPC_DEPS}

@ -0,0 +1,213 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/distributed/communicator.h"
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <thread> // NOLINT
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
DEFINE_bool(communicator_independent_recv_thread, true,
"use an independent to recv vars from parameter server");
DEFINE_int32(communicator_send_queue_size, 20,
"queue size to recv gradient before send");
DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
"max grad num to send before recv parameters");
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
DEFINE_int32(communicator_max_merge_var_num, 20,
"max var num to merge and send");
DEFINE_bool(communicator_fake_rpc, false,
"fake mode does not really send any thing");
namespace paddle {
namespace operators {
namespace distributed {
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
std::unique_ptr<Communicator> Communicator::communicator_(nullptr);
std::once_flag Communicator::init_flag_;
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope)
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
// get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread;
VLOG(0) << "communicator_send_queue_size: "
<< FLAGS_communicator_send_queue_size;
VLOG(0) << "communicator_max_send_grad_num_before_recv: "
<< FLAGS_communicator_max_send_grad_num_before_recv;
VLOG(0) << "communicator_thread_pool_size: "
<< FLAGS_communicator_thread_pool_size;
VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
}
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
Communicator::~Communicator() {
VLOG(3) << "~Communicator";
running_ = false;
if (send_thread_) send_thread_->join();
if (recv_thread_) recv_thread_->join();
VLOG(3) << "~Communicator done";
}
void Communicator::SendThread() {
VLOG(3) << "SendThread start!";
while (running_) {
std::vector<std::future<void>> task_futures;
task_futures.reserve(send_varname_to_ctx_.size());
VLOG(3) << "run send graph";
auto before_run_send_graph = GetCurrentUS();
for (auto &iter : send_varname_to_queue_) {
auto &var_name = iter.first;
auto &var_queue = iter.second;
if (var_queue->Size() > 0) {
auto send_task = [this, &var_name, &var_queue] {
VLOG(3) << var_name << " merge and send";
std::vector<std::shared_ptr<Variable>> vars;
size_t merged_var_num = 0;
while (var_queue->Size() > 0 &&
merged_var_num < FLAGS_communicator_max_merge_var_num) {
vars.push_back(var_queue->Pop());
// only count the send number of the first var
if (var_name == send_varname_to_queue_.begin()->first) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
merged_var_num++;
}
auto before_merge = GetCurrentUS();
MergeVars(var_name, vars, send_scope_.get());
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << var_name << " use time "
<< after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true);
}
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
};
task_futures.emplace_back(
send_threadpool_->enqueue(std::move(send_task)));
} else {
VLOG(3) << var_name << " queue empty";
}
}
for (auto &task_f : task_futures) {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
if (send_graph_use_time > 100) {
VLOG(1) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
}
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll();
}
}
}
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(3) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
}
void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
auto grad_num = grad_num_.load();
if (grad_num > FLAGS_communicator_max_send_grad_num_before_recv) {
VLOG(1) << "current grad num " << grad_num;
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
}
void Communicator::Send(const std::string &var_name,
const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name
auto *grad_var = scope.FindVar(var_name);
PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
auto tmp_grad_var = std::make_shared<Variable>();
framework::CopyVariable(*grad_var, tmp_grad_var.get());
auto &queue = send_varname_to_queue_.at(var_name);
VLOG(3) << "send " << var_name << " queue size " << queue->Size();
queue->Push(tmp_grad_var);
}
Communicator *Communicator::GetInstance() { return communicator_.get(); }
void Communicator::Start() {
running_ = true;
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
if (FLAGS_communicator_independent_recv_thread) {
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle

@ -0,0 +1,219 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace distributed {
using Scope = framework::Scope;
using Variable = framework::Variable;
template <typename T>
class BlockingQueue {
public:
explicit BlockingQueue(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0, "The capacity must be greater than 0.");
}
bool Push(const T& elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.push_back(elem);
}
cv_.notify_one();
return true;
}
bool Push(T&& elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return queue_.size() < capacity_; });
PADDLE_ENFORCE_LT(queue_.size(), capacity_);
queue_.emplace_back(std::move(elem));
}
cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
inline void MergeVars(const std::string& var_name,
const std::vector<std::shared_ptr<Variable>>& vars,
Scope* scope) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto& var0 = vars[0];
auto* out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor " << dims;
// init output tensor
auto* out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<float>(dims, cpu_place);
// check the input dims
for (auto& var : vars) {
auto& var_t = var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(var_t.dims(), dims, "should have the same dims");
}
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, float>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<float>(0));
// sum all vars to out
auto result = EigenVector<float>::Flatten(*out_t);
for (auto& var : vars) {
auto& in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<float>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto& slr0 = var0->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.reserve(vars.size());
for (auto& var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
merge_add;
auto dev_ctx = paddle::platform::CPUDeviceContext();
merge_add(dev_ctx, inputs, out_slr, false);
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims();
} else {
PADDLE_THROW("unsupported var type!");
}
}
using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);
~Communicator();
void Start();
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
private:
// recv all parameter
void RecvAll();
void SendThread();
void RecvThread();
bool running_ = false;
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::unique_ptr<std::thread> send_thread_;
std::unique_ptr<std::thread> recv_thread_;
Scope* recv_scope_; // should be global scope
std::unique_ptr<Scope> send_scope_; // an independent scope
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
// the following code is for initialize the commnunicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
InitImpl(send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
}
static Communicator* GetInstance();
private:
// Init is called by GetInstance.
static void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
private:
static std::once_flag init_flag_;
static std::unique_ptr<Communicator> communicator_;
};
} // namespace distributed
} // namespace operators
} // namespace paddle

@ -0,0 +1,110 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "paddle/fluid/operators/distributed/communicator.h"
namespace paddle {
namespace operators {
namespace distributed {
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
TEST(communicator, merge_lod_tensors) {
auto cpu_place = platform::CPUPlace();
auto dims = framework::make_ddim({2, 3});
std::vector<std::shared_ptr<framework::Variable>> in_vars;
float out_value = 0;
for (auto i = 0; i < 10; ++i) {
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *tensor = var->GetMutable<LoDTensor>();
auto *data = tensor->mutable_data<float>(dims, cpu_place);
for (auto j = 0; j < tensor->numel(); ++j) {
data[j] = static_cast<float>(i);
}
out_value += static_cast<float>(i);
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
}
auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
auto *out_data = out_tensor.data<float>();
ASSERT_EQ(out_tensor.dims(), dims);
for (auto i = 0; i < out_tensor.numel(); ++i) {
ASSERT_EQ(out_data[i], out_value);
}
}
TEST(communicator, merge_selected_rows) {
auto cpu_place = platform::CPUPlace();
int64_t width = 10;
std::vector<std::shared_ptr<framework::Variable>> in_vars;
const int64_t height = 100;
for (auto i = 0; i < 10; ++i) {
std::vector<int64_t> rows;
for (auto k = 0; k <= i; ++k) {
rows.push_back(k);
}
auto var = std::make_shared<Variable>();
in_vars.emplace_back(var);
auto *slr = var->GetMutable<SelectedRows>();
slr->set_height(height);
slr->set_rows(rows);
auto dims =
framework::make_ddim({static_cast<int64_t>(rows.size()), width});
auto *data = slr->mutable_value()->mutable_data<float>(dims, cpu_place);
for (auto i = 0; i < rows.size(); ++i) {
for (auto j = 0; j < width; ++j) {
data[i * width + j] = static_cast<float>(rows[i]);
}
}
}
const std::string out_name = "Out";
std::unique_ptr<framework::Scope> scope;
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
}
auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
auto &out_t = out_slr.value();
auto *out_data = out_t.data<float>();
ASSERT_EQ(out_t.dims(), framework::make_ddim({10, width}));
std::vector<float> out_values;
out_values.reserve(10);
for (auto i = 0; i < 10; ++i) {
out_values.push_back(static_cast<float>(i * (10 - i)));
}
for (auto i = 0; i < out_slr.rows().size(); ++i) {
ASSERT_EQ(out_slr.rows()[i], i);
for (auto j = 0; j < width; ++j) {
ASSERT_EQ(out_data[i * width + j], out_values[i]);
}
}
}
} // namespace distributed
} // namespace operators
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save