|
|
|
@ -28,42 +28,79 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
struct OpHandle;
|
|
|
|
|
|
|
|
|
|
struct VarHandle {
|
|
|
|
|
struct VarHandleBase {
|
|
|
|
|
virtual ~VarHandleBase() {}
|
|
|
|
|
virtual std::string DebugString() const = 0;
|
|
|
|
|
|
|
|
|
|
OpHandle *generated_op_;
|
|
|
|
|
std::vector<OpHandle *> pending_ops_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct VarHandle : public VarHandleBase {
|
|
|
|
|
std::string DebugString() const override {
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << name_ << ":" << place_;
|
|
|
|
|
return ss.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t version_;
|
|
|
|
|
std::string name_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
OpHandle *generated_op_;
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandle *> pending_ops_;
|
|
|
|
|
struct DependencyVarHandle : public VarHandleBase {
|
|
|
|
|
std::string DebugString() const override { return "Deps var"; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpHandle {
|
|
|
|
|
std::vector<VarHandle *> inputs_;
|
|
|
|
|
std::vector<VarHandle *> outputs_;
|
|
|
|
|
std::vector<VarHandleBase *> inputs_;
|
|
|
|
|
std::vector<VarHandleBase *> outputs_;
|
|
|
|
|
std::unordered_map<platform::Place, platform::DeviceContext *,
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
dev_ctx_;
|
|
|
|
|
|
|
|
|
|
std::string DebugString() {
|
|
|
|
|
std::stringstream ss;
|
|
|
|
|
ss << "(";
|
|
|
|
|
for (auto *var : inputs_) {
|
|
|
|
|
ss << var->name_ << ":" << var->place_ << ", ";
|
|
|
|
|
ss << var->DebugString() << ", ";
|
|
|
|
|
}
|
|
|
|
|
ss << ") --> (";
|
|
|
|
|
for (auto *var : outputs_) {
|
|
|
|
|
ss << var->name_ << ":" << var->place_ << ", ";
|
|
|
|
|
ss << var->DebugString() << ", ";
|
|
|
|
|
}
|
|
|
|
|
ss << ")\n";
|
|
|
|
|
return ss.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual ~OpHandle() {}
|
|
|
|
|
|
|
|
|
|
virtual void Run() {}
|
|
|
|
|
virtual void Wait() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
std::unique_ptr<OperatorBase> op_;
|
|
|
|
|
Scope *scope_;
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
|
|
|
|
|
explicit ComputationOpHandle(const OpDesc &op_desc)
|
|
|
|
|
: op_(framework::OpRegistry::CreateOp(op_desc)) {}
|
|
|
|
|
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
|
|
|
|
platform::Place place)
|
|
|
|
|
: op_(framework::OpRegistry::CreateOp(op_desc)),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
place_(place) {}
|
|
|
|
|
|
|
|
|
|
void Run() override {
|
|
|
|
|
// Wait other op if necessary
|
|
|
|
|
auto *cur_ctx = dev_ctx_[place_];
|
|
|
|
|
for (auto *in : inputs_) {
|
|
|
|
|
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
|
|
|
|
|
in->generated_op_->Wait();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_->Run(*scope_, place_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Placeholder op for scaling the loss gradient; overrides no behavior yet.
struct ScaleLossGradOpHandle : public OpHandle {};
|
|
|
|
@ -122,12 +159,27 @@ class ParallelExecutorPrivate {
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Device context used for cross-device communication on `place`: the
// regular pool context for CPU places or single-device runs, otherwise the
// NCCL context (CUDA builds only).
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
  if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
    return const_cast<platform::DeviceContext *>(
        platform::DeviceContextPool::Instance().Get(place));
  } else {
#ifdef PADDLE_WITH_CUDA
    return GetNCCLCtx(place).ctx_.get();
#else
    // NOTE(review): no trailing ';' after PADDLE_THROW — presumably the
    // macro supplies it; confirm against the macro definition.
    PADDLE_THROW("Not compiled with CUDA")
#endif
  }
}
|
|
|
|
|
|
|
|
|
|
platform::Place main_place_;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<platform::Place,
|
|
|
|
|
std::unordered_map<std::string, std::map<int, VarHandle>>,
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
vars_;
|
|
|
|
|
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpHandle>> ops_;
|
|
|
|
|
|
|
|
|
|
ThreadPool pool_;
|
|
|
|
@ -170,7 +222,7 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
|
void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
|
const ProgramDesc &main_program, const std::string &loss_var_name) const {
|
|
|
|
|
std::unordered_set<std::__cxx11::string> grads;
|
|
|
|
|
std::unordered_set<std::string> grads;
|
|
|
|
|
for (auto &each_param : params) {
|
|
|
|
|
grads.insert(each_param + "@GRAD");
|
|
|
|
|
}
|
|
|
|
@ -188,8 +240,11 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &pair : member_->local_scopes_) {
|
|
|
|
|
member_->ops_.emplace_back(new ComputationOpHandle(*op));
|
|
|
|
|
member_->ops_.emplace_back(
|
|
|
|
|
new ComputationOpHandle(*op, pair.second, pair.first));
|
|
|
|
|
auto *op_handle = member_->ops_.back().get();
|
|
|
|
|
op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(pair.first));
|
|
|
|
|
|
|
|
|
|
auto var_names = op->InputArgumentNames();
|
|
|
|
|
|
|
|
|
@ -210,8 +265,11 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
|
|
|
|
|
// Insert ScaleCost OpHandle
|
|
|
|
|
member_->ops_.emplace_back(new ScaleLossGradOpHandle());
|
|
|
|
|
|
|
|
|
|
op_handle = member_->ops_.back().get();
|
|
|
|
|
|
|
|
|
|
op_handle->dev_ctx_[pair.first] =
|
|
|
|
|
member_->CommunicationDevCtx(pair.first);
|
|
|
|
|
|
|
|
|
|
auto &place = pair.first;
|
|
|
|
|
VarHandle *loss = GetVarHandle(loss_var_name, place);
|
|
|
|
|
loss->pending_ops_.emplace_back(op_handle);
|
|
|
|
@ -251,11 +309,54 @@ void ParallelExecutor::ConstructDependencyGraph(
|
|
|
|
|
var.name_ = og;
|
|
|
|
|
var.version_ = vars.size() - 1;
|
|
|
|
|
op_handle->outputs_.emplace_back(&var);
|
|
|
|
|
|
|
|
|
|
for (auto &pair : member_->local_scopes_) {
|
|
|
|
|
op_handle->dev_ctx_[pair.first] =
|
|
|
|
|
member_->CommunicationDevCtx(pair.first);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Dependency graph has been constructed. However, there are still data
|
|
|
|
|
* hazards need to be handled.
|
|
|
|
|
*
|
|
|
|
|
* We only handle write after read(WAR), since it should not have a write
|
|
|
|
|
* after write in program. If there are write after write operators, we need
|
|
|
|
|
* prune them.
|
|
|
|
|
*
|
|
|
|
|
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
|
if (name_pair.second.size() <= 1) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
auto it_new = name_pair.second.rbegin();
|
|
|
|
|
auto it_old = name_pair.second.rbegin();
|
|
|
|
|
++it_old;
|
|
|
|
|
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
|
|
|
|
|
auto *write_op = it_new->second.generated_op_;
|
|
|
|
|
auto &read_ops = it_old->second.pending_ops_;
|
|
|
|
|
|
|
|
|
|
for (auto *read_op : read_ops) {
|
|
|
|
|
// Manually add a dependency var from read_op to write_op;
|
|
|
|
|
|
|
|
|
|
auto *dep_var = new DependencyVarHandle();
|
|
|
|
|
dep_var->generated_op_ = read_op;
|
|
|
|
|
read_op->outputs_.emplace_back(dep_var);
|
|
|
|
|
|
|
|
|
|
dep_var->pending_ops_.emplace_back(write_op);
|
|
|
|
|
write_op->inputs_.emplace_back(dep_var);
|
|
|
|
|
member_->dep_vars_.emplace(dep_var);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
|
|
|
|
@ -349,7 +450,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
|
// Version --> VarHandle
|
|
|
|
|
|
|
|
|
|
std::unordered_map<VarHandle *, bool> pending_vars;
|
|
|
|
|
std::unordered_map<VarHandleBase *, bool> pending_vars;
|
|
|
|
|
std::unordered_map<OpHandle *, size_t> pending_ops;
|
|
|
|
|
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
@ -361,12 +462,16 @@ std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &var : member_->dep_vars_) {
|
|
|
|
|
pending_vars[var.get()] = var->generated_op_ == nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &op : member_->ops_) {
|
|
|
|
|
pending_ops.insert({op.get(), op->inputs_.size()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
while (!pending_ops.empty()) {
|
|
|
|
|
VarHandle *ready_var = nullptr;
|
|
|
|
|
VarHandleBase *ready_var = nullptr;
|
|
|
|
|
for (auto &pair : pending_vars) {
|
|
|
|
|
if (pair.second) {
|
|
|
|
|
ready_var = pair.first;
|
|
|
|
@ -400,7 +505,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
|
|
|
|
|
|
auto op_run = [ready_buffer, op] {
|
|
|
|
|
// TODO(yy) Check Previous Op has same dev ctx.
|
|
|
|
|
LOG(INFO) << "Run " << op->DebugString();
|
|
|
|
|
op->Run();
|
|
|
|
|
for (auto *ready : ready_buffer) {
|
|
|
|
|
*ready = true;
|
|
|
|
|
}
|
|
|
|
|