// paddle/fluid/framework/parallel_executor.cc
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"

#include <typeindex>

#include "lod_tensor.h"
#include "op_registry.h"
#include "threadpool.h"
namespace paddle {
namespace framework {
7 years ago
#ifdef PADDLE_WITH_CUDA
// FIXME: CHECK the return value of x;
#define NCCL_INVOKE(x) x
#endif
struct OpHandle;

// Base class for a node in the SSA dependency graph that represents a
// variable version (or a pure dependency edge, see DependencyVarHandle).
struct VarHandleBase {
  virtual ~VarHandleBase() {}
  virtual std::string DebugString() const = 0;

  // Op that produces this variable version; nullptr when nothing in the
  // graph generates it (e.g. a graph input). Default-initialized so a
  // freshly created handle never carries a dangling pointer.
  OpHandle *generated_op_{nullptr};
  // Ops that consume this variable version.
  std::vector<OpHandle *> pending_ops_;
};
// A concrete SSA variable: one (name, place, version) triple. A new version
// is created each time an op writes the variable.
struct VarHandle : public VarHandleBase {
  std::string DebugString() const override {
    std::stringstream ss;
    ss << name_ << ":" << place_;
    return ss.str();
  }

  // SSA version; 0 is the initial (never-written) version. Brace-initialized
  // so a default-constructed handle does not hold an indeterminate value.
  size_t version_{0};
  std::string name_;
  platform::Place place_;
};
// A synthetic variable that carries no data; it only encodes an explicit
// ordering edge (used below for write-after-read hazards).
struct DependencyVarHandle : public VarHandleBase {
  std::string DebugString() const override { return "Deps var"; }
};
// Base class for a node in the SSA graph that represents an operation.
struct OpHandle {
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;
  // Device context this op uses on each place (non-owning pointers).
  std::unordered_map<platform::Place, platform::DeviceContext *,
                     platform::PlaceHash>
      dev_ctx_;

  // Human-readable "(inputs) --> (outputs)" summary for logging.
  // Made const: it only reads the handle, and callers are unaffected.
  std::string DebugString() const {
    std::stringstream ss;
    ss << "(";
    for (auto *var : inputs_) {
      ss << var->DebugString() << ", ";
    }
    ss << ") --> (";
    for (auto *var : outputs_) {
      ss << var->DebugString() << ", ";
    }
    ss << ")\n";
    return ss.str();
  }

  virtual ~OpHandle() {}

  // No-op defaults; concrete op handles override these.
  virtual void Run() {}
  virtual void Wait() {}
};
// An OpHandle that runs one framework operator on a single place.
struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
  Scope *scope_;  // local scope to execute in; assigned after construction
  platform::Place place_;

  explicit ComputationOpHandle(const OpDesc &op_desc, platform::Place place)
      : op_(framework::OpRegistry::CreateOp(op_desc)),
        // NOTE(review): scope_ starts null and Run() dereferences it —
        // callers must set it before scheduling; TODO confirm they all do.
        scope_(nullptr),
        place_(place) {}

  void Run() override {
    // Wait other op if necessary
    LOG(INFO) << DebugString();
    auto *cur_ctx = dev_ctx_[place_];
    // Wait for producers that run on a different device context so their
    // outputs are visible before this op starts.
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
        in->generated_op_->Wait();
      }
    }

    op_->Run(*scope_, place_);
  }
};
// FIXME(review): both handles below are stubs — Run() is inherited as a
// no-op, so loss-gradient scaling and the NCCL all-reduce are not yet
// implemented.
struct ScaleLossGradOpHandle : public OpHandle {};

struct NCCLAllReduceOpHandle : public OpHandle {};
class ParallelExecutorPrivate {
public:
7 years ago
explicit ParallelExecutorPrivate(size_t num_threads = 12)
: pool_(num_threads) {}
std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
local_scopes_;
7 years ago
#ifdef PADDLE_WITH_CUDA
struct NCCLContext {
std::unique_ptr<platform::CUDADeviceContext> ctx_;
ncclComm_t comm;
explicit NCCLContext(int dev_id) {
ctx_.reset(new platform::CUDADeviceContext(platform::CUDAPlace(dev_id)));
}
cudaStream_t stream() const { return ctx_->stream(); }
int device_id() const {
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
}
static void InitNCCLContext(std::map<int, NCCLContext> &contexts) {
std::vector<ncclComm_t> comms;
std::vector<int> devs;
comms.resize(contexts.size());
devs.reserve(contexts.size());
for (auto &ctx : contexts) {
devs.push_back(ctx.first);
}
NCCL_INVOKE(platform::dynload::ncclCommInitAll(
&comms[0], static_cast<int>(contexts.size()), &devs[0]));
int i = 0;
for (auto &ctx : contexts) {
ctx.second.comm = comms[i++];
}
}
};
std::map<int, NCCLContext> communication_streams_;
NCCLContext &GetNCCLCtx(platform::Place p) {
int dev_id = boost::get<platform::CUDAPlace>(p).device;
return communication_streams_.at(dev_id);
}
#endif
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
return const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
} else {
#ifdef PADDLE_WITH_CUDA
return GetNCCLCtx(place).ctx_.get();
#else
PADDLE_THROW("Not compiled with CUDA")
#endif
}
}
platform::Place main_place_;
std::unordered_map<platform::Place,
std::unordered_map<std::string, std::map<int, VarHandle>>,
platform::PlaceHash>
vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandle>> ops_;
7 years ago
ThreadPool pool_;
};
// TODO(yy): Move this function somewhere
// Map a tensor element type to the corresponding NCCL data type.
ncclDataType_t ToNCCLDataType(std::type_index type) {
  if (type == typeid(float)) {
    return ncclFloat;
  } else if (type == typeid(double)) {
    return ncclDouble;
  } else if (type == typeid(int)) {
    return ncclInt;
  }
  // FIXME(review): unknown element types silently fall back to float; this
  // preserves the old hard-coded behavior but should probably throw.
  return ncclFloat;
}
// Runs the startup program on the first place, creates one local scope per
// place, optionally initializes NCCL + broadcasts parameters for multi-GPU
// runs, then builds the SSA dependency graph for main_program.
// NOTE(review): assumes `places` is non-empty — places[0] is used unchecked.
ParallelExecutor::ParallelExecutor(
    const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const ProgramDesc &startup_program, const ProgramDesc &main_program,
    const std::string &loss_var_name, Scope *scope)
    : member_(new ParallelExecutorPrivate()) {
  // Step 1. RunStartupProgram and Bcast the params to devs.
  Executor exe(places[0]);
  exe.Run(startup_program, scope, 0);
  // Create local scopes
  for (auto &place : places) {
    member_->local_scopes_[place] = &scope->NewScope();
  }
  member_->main_place_ = places[0];

  // Bcast Parameters to all GPUs
  if (platform::is_gpu_place(member_->main_place_) &&
      member_->local_scopes_.size() != 1) {  // Is CUDA
    BuildNCCLCommunicator();
    BCastParamsToGPUs(startup_program);
  }
  // Startup Program has been run. All local scopes has correct parameters.

  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
  // ncclOp
  ConstructDependencyGraph(params, main_program, loss_var_name);
}
void ParallelExecutor::ConstructDependencyGraph(
const std::unordered_set<std::string> &params,
const ProgramDesc &main_program, const std::string &loss_var_name) const {
std::unordered_set<std::string> grads;
for (auto &each_param : params) {
grads.insert(each_param + "@GRAD");
}
bool is_forwarding = true;
for (auto *op : main_program.Block(0).AllOps()) {
bool change_forward = false;
if (!is_forwarding) {
// FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 &&
op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
continue; // Drop fill 1. for backward coeff;
}
}
for (auto &pair : member_->local_scopes_) {
7 years ago
member_->ops_.emplace_back(new ComputationOpHandle(*op, pair.first));
auto *op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(pair.first));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
auto &place = pair.first;
VarHandle *var = GetVarHandle(each_var_name, place);
op_handle->inputs_.emplace_back(var);
7 years ago
var->pending_ops_.emplace_back(op_handle);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
auto &place = pair.first;
GenerateVar(op_handle, each_var_name, place);
}
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle
member_->ops_.emplace_back(new ScaleLossGradOpHandle());
op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] =
member_->CommunicationDevCtx(pair.first);
auto &place = pair.first;
VarHandle *loss = GetVarHandle(loss_var_name, place);
7 years ago
loss->pending_ops_.emplace_back(op_handle);
op_handle->inputs_.emplace_back(loss);
GenerateVar(op_handle, loss_var_name + "@GRAD", place);
change_forward = true;
LOG(INFO) << "Scale Loss " << op_handle->DebugString();
}
}
}
if (change_forward) {
is_forwarding = false;
}
if (!is_forwarding) {
auto var_names = op->OutputArgumentNames();
for (auto &og : var_names) {
if (grads.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
member_->ops_.emplace_back(new NCCLAllReduceOpHandle());
auto *op_handle = member_->ops_.back().get();
for (auto &pair : member_->local_scopes_) {
auto &place = pair.first;
auto &vars = member_->vars_[place][og];
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->inputs_.emplace_back(prev_grad);
7 years ago
prev_grad->pending_ops_.emplace_back(op_handle);
auto &var = vars[vars.size()];
var.place_ = place;
var.generated_op_ = op_handle;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
for (auto &pair : member_->local_scopes_) {
op_handle->dev_ctx_[pair.first] =
member_->CommunicationDevCtx(pair.first);
}
}
}
}
}
}
/**
* Dependency graph has been constructed. However, there are still data
* harzaeds need to be handled.
*
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) {
if (name_pair.second.size() <= 1) {
return;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
auto *dep_var = new DependencyVarHandle();
dep_var->generated_op_ = read_op;
read_op->outputs_.emplace_back(dep_var);
dep_var->pending_ops_.emplace_back(write_op);
write_op->inputs_.emplace_back(dep_var);
member_->dep_vars_.emplace(dep_var);
}
}
}
}
}
// Appends a fresh SSA version of `each_var_name` on `place` and records
// `op_handle` as its producer (also wiring the handle's output list).
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
                                   const std::string &each_var_name,
                                   const platform::Place &place) const {
  auto &versioned_vars = member_->vars_[place][each_var_name];
  size_t new_version = versioned_vars.size();

  VarHandle &handle = versioned_vars[new_version];
  handle.name_ = each_var_name;
  handle.place_ = place;
  handle.version_ = new_version;
  handle.generated_op_ = op_handle;

  op_handle->outputs_.emplace_back(&handle);
}
// Returns the latest SSA version of a variable on `place`, lazily creating
// version 0 (with no generating op) when the variable has never been seen.
VarHandle *ParallelExecutor::GetVarHandle(const std::string &each_var_name,
                                          const platform::Place &place) const {
  auto &versioned_vars = member_->vars_[place][each_var_name];
  if (!versioned_vars.empty()) {
    // Most recent version is the one with the largest key.
    return &versioned_vars.rbegin()->second;
  }

  VarHandle &init_var = versioned_vars[0];
  init_var.name_ = each_var_name;
  init_var.place_ = place;
  init_var.version_ = 0;
  init_var.generated_op_ = nullptr;
  return &init_var;
}
// Prepares (and will eventually perform) an NCCL broadcast of every
// LOD_TENSOR created by the startup program from the main place to all other
// local scopes. The actual ncclBcast call is still a TODO below.
void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
#ifdef PADDLE_WITH_CUDA
  auto *main_scope = member_->local_scopes_[member_->main_place_];

  for (auto *var_desc : startup_program.Block(0).AllVars()) {
    if (var_desc->GetType() == proto::VarType::LOD_TENSOR) {
      auto &main_tensor =
          main_scope->FindVar(var_desc->Name())->Get<LoDTensor>();

      ncclDataType_t data_type = ToNCCLDataType(main_tensor.type());
      auto &dims = main_tensor.dims();
      size_t numel = main_tensor.numel();

      // (buffer, nccl context) pairs; mems[0] is the broadcast source, the
      // rest are destinations.
      std::vector<std::pair<void *, ParallelExecutorPrivate::NCCLContext *>>
          mems;
      mems.emplace_back(const_cast<void *>(main_tensor.data<void>()),
                        &member_->GetNCCLCtx(member_->main_place_));

      for (auto &pair : member_->local_scopes_) {
        if (pair.first == member_->main_place_) {
          continue;
        }

        auto local_scope = pair.second;
        auto *t = local_scope->Var(var_desc->Name())->GetMutable<LoDTensor>();
        t->Resize(dims);
        // NOTE(review): every destination is paired with the *main* place's
        // NCCL context — presumably this should be GetNCCLCtx(pair.first);
        // verify before wiring up the actual ncclBcast call.
        mems.emplace_back(t->mutable_data(pair.first, main_tensor.type()),
                          &member_->GetNCCLCtx(member_->main_place_));
      }

      // TODO(yy): Invoke ncclBCast here. mems, numel, data_type. The mems[0]
      // is the src, rests are dests.

      // Silence unused-variable warnings until the broadcast is implemented.
      (void)(data_type);
      (void)(numel);
    }
  }
#else
  PADDLE_THROW("Not compiled with CUDA");
#endif
}
7 years ago
// Creates one NCCL context per device that owns a local scope, then
// initializes all communicators in a single group call.
void ParallelExecutor::BuildNCCLCommunicator() const {
#ifdef PADDLE_WITH_CUDA
  for (auto &scope_pair : member_->local_scopes_) {
    auto place = scope_pair.first;
    int device_id = boost::get<platform::CUDAPlace>(place).device;
    member_->communication_streams_.emplace(
        device_id, ParallelExecutorPrivate::NCCLContext(device_id));
  }

  ParallelExecutorPrivate::NCCLContext::InitNCCLContext(
      member_->communication_streams_);
#endif
}
// Topologically schedules the SSA graph on the thread pool:
//   pending_vars maps each variable handle to "is it ready yet";
//   pending_ops maps each op to the number of inputs it still waits for.
// An op is submitted once all of its inputs are ready.
std::vector<LoDTensor> ParallelExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Version --> VarHandle
  std::unordered_map<VarHandleBase *, bool> pending_vars;
  std::unordered_map<OpHandle *, size_t> pending_ops;

  for (auto &place_pair : member_->vars_) {
    for (auto &name_pair : place_pair.second) {
      for (auto &version_pair : name_pair.second) {
        // A variable with no generating op is ready from the start.
        pending_vars[&version_pair.second] =
            version_pair.second.generated_op_ == nullptr;
      }
    }
  }

  for (auto &var : member_->dep_vars_) {
    pending_vars[var.get()] = var->generated_op_ == nullptr;
  }

  for (auto &op : member_->ops_) {
    pending_ops.insert({op.get(), op->inputs_.size()});
  }

  while (!pending_ops.empty()) {
    VarHandleBase *ready_var = nullptr;
    // Linear scan for any ready variable (keeps the last one found).
    for (auto &pair : pending_vars) {
      if (pair.second) {
        ready_var = pair.first;
      }
    }

    if (ready_var == nullptr) {
      // Nothing ready yet; block until in-flight ops finish and flip flags.
      member_->pool_.Wait();  // Wait thread pool;
      continue;
    }

    pending_vars.erase(ready_var);

    std::vector<OpHandle *> to_run;

    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        to_run.emplace_back(op);
      }
    }

    for (auto *op : to_run) {
      pending_ops.erase(op);

      // Pointers into pending_vars for each output; set to true when the op
      // finishes so the scheduler can release dependent ops.
      // NOTE(review): these flags are written from pool threads while this
      // loop reads pending_vars without synchronization — looks like a data
      // race; confirm pool_.Wait() provides the required ordering.
      std::vector<bool *> ready_buffer;
      for (auto *var : op->outputs_) {
        ready_buffer.emplace_back(&pending_vars[var]);
      }

      auto op_run = [ready_buffer, op] {
        // TODO(yy) Check Previous Op has same dev ctx.
        op->Run();

        for (auto *ready : ready_buffer) {
          *ready = true;
        }
      };

      member_->pool_.Run(op_run);
    }
  }

  // NOTE(review): fetch_tensors is currently ignored; nothing is fetched yet.
  return std::vector<LoDTensor>();
}
} // namespace framework
7 years ago
} // namespace paddle