|
|
|
@ -53,6 +53,10 @@ struct VarHandle : public VarHandleBase {
|
|
|
|
|
platform::Place place_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct DummyVarHandle : public VarHandleBase {
|
|
|
|
|
std::string DebugString() const override { return "dummy"; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct DependencyVarHandle : public VarHandleBase {
|
|
|
|
|
std::string DebugString() const override { return "Dependency Variable"; }
|
|
|
|
|
};
|
|
|
|
@ -643,6 +647,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
member_->exception_.reset();
|
|
|
|
|
std::unordered_map<VarHandleBase *, GuardedBool> pending_vars;
|
|
|
|
|
std::unordered_map<OpHandle *, size_t> pending_ops;
|
|
|
|
|
std::vector<DummyVarHandle> dummy_vars;
|
|
|
|
|
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
@ -696,17 +701,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
var->pending_ops_.emplace(op);
|
|
|
|
|
op->inputs_.emplace_back(var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dummy_vars.emplace_back();
|
|
|
|
|
auto *var = &dummy_vars.back();
|
|
|
|
|
op->outputs_.emplace_back(var);
|
|
|
|
|
var->generated_op_ = op;
|
|
|
|
|
pending_vars[var] = false;
|
|
|
|
|
|
|
|
|
|
pending_ops.insert({op, op->inputs_.size()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::future<void>> op_threads;
|
|
|
|
|
op_threads.reserve(pending_ops.size() + to_run.size());
|
|
|
|
|
|
|
|
|
|
for (auto *op : to_run) {
|
|
|
|
|
op_threads.emplace_back(RunOp(pending_vars, op));
|
|
|
|
|
RunOp(pending_vars, op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
while (!pending_ops.empty()) {
|
|
|
|
|
while (!pending_vars.empty()) {
|
|
|
|
|
VarHandleBase *ready_var = nullptr;
|
|
|
|
|
for (auto &pair : pending_vars) {
|
|
|
|
|
if (pair.second) {
|
|
|
|
@ -715,12 +724,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
}
|
|
|
|
|
if (ready_var == nullptr) {
|
|
|
|
|
// FIXME: use a condition variable instead of busy-waiting.
|
|
|
|
|
|
|
|
|
|
if (member_->exception_) {
|
|
|
|
|
throw * member_->exception_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(3) << pending_vars.size();
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
pending_vars.erase(ready_var);
|
|
|
|
@ -734,20 +740,16 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
}
|
|
|
|
|
for (auto *op : to_run) {
|
|
|
|
|
pending_ops.erase(op);
|
|
|
|
|
op_threads.emplace_back(RunOp(pending_vars, op));
|
|
|
|
|
RunOp(pending_vars, op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &t : op_threads) {
|
|
|
|
|
t.get(); // Join all workers
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fetch_ops.clear();
|
|
|
|
|
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
|
|
|
|
|
fetched_data->tensors_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::future<void> ParallelExecutor::RunOp(
|
|
|
|
|
void ParallelExecutor::RunOp(
|
|
|
|
|
std::unordered_map<VarHandleBase *, GuardedBool> &pending_vars,
|
|
|
|
|
OpHandle *op) const {
|
|
|
|
|
std::vector<GuardedBool *> *ready_buffer = new std::vector<GuardedBool *>();
|
|
|
|
@ -768,7 +770,7 @@ std::future<void> ParallelExecutor::RunOp(
|
|
|
|
|
LOG(FATAL) << "Unknown exception catched";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
return member_->pool_.enqueue(op_run);
|
|
|
|
|
member_->pool_.enqueue(op_run);
|
|
|
|
|
}
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|