|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
#include "paddle/fluid/framework/parallel_executor.h"
|
|
|
|
#include "paddle/fluid/framework/parallel_executor.h"
|
|
|
|
#include "lod_tensor.h"
|
|
|
|
#include "lod_tensor.h"
|
|
|
|
#include "op_registry.h"
|
|
|
|
#include "op_registry.h"
|
|
|
|
|
|
|
|
#include "threadpool.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
@ -34,7 +35,6 @@ struct VarHandle {
|
|
|
|
struct OpHandle {
|
|
|
|
struct OpHandle {
|
|
|
|
std::vector<VarHandle *> inputs_;
|
|
|
|
std::vector<VarHandle *> inputs_;
|
|
|
|
std::vector<VarHandle *> outputs_;
|
|
|
|
std::vector<VarHandle *> outputs_;
|
|
|
|
platform::DeviceContext *dev_ctx_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::string DebugString() {
|
|
|
|
std::string DebugString() {
|
|
|
|
std::stringstream ss;
|
|
|
|
std::stringstream ss;
|
|
|
@ -66,6 +66,9 @@ struct NCCLAllReduceOpHandle : public OpHandle {};
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
|
|
|
|
// Constructs the private executor state, initializing the `pool_` member
// (a ThreadPool) from `num_threads`; callers that pass nothing get 12.
// NOTE(review): presumably `num_threads` is the pool's worker count — confirm
// against the ThreadPool constructor, which is not visible here.
explicit ParallelExecutorPrivate(size_t num_threads = 12)
|
|
|
|
|
|
|
|
: pool_(num_threads) {}
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
|
|
|
|
std::unordered_map<platform::Place, Scope *, platform::PlaceHash>
|
|
|
|
local_scopes_;
|
|
|
|
local_scopes_;
|
|
|
|
std::unordered_map<platform::Place, platform::CUDADeviceContext,
|
|
|
|
std::unordered_map<platform::Place, platform::CUDADeviceContext,
|
|
|
@ -78,6 +81,8 @@ class ParallelExecutorPrivate {
|
|
|
|
platform::PlaceHash>
|
|
|
|
platform::PlaceHash>
|
|
|
|
vars_;
|
|
|
|
vars_;
|
|
|
|
std::vector<std::unique_ptr<OpHandle>> ops_;
|
|
|
|
std::vector<std::unique_ptr<OpHandle>> ops_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ThreadPool pool_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(yy): Move this function somewhere
|
|
|
|
// TODO(yy): Move this function somewhere
|
|
|
@ -285,13 +290,15 @@ void ParallelExecutor::BCastParamsToGPUs(
|
|
|
|
std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
const std::vector<std::string> &fetch_tensors) {
|
|
|
|
// Version --> VarHandle
|
|
|
|
// Version --> VarHandle
|
|
|
|
std::unordered_set<VarHandle *> pending_vars;
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_map<VarHandle *, bool> pending_vars;
|
|
|
|
std::unordered_map<OpHandle *, size_t> pending_ops;
|
|
|
|
std::unordered_map<OpHandle *, size_t> pending_ops;
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
for (auto &place_pair : member_->vars_) {
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
for (auto &name_pair : place_pair.second) {
|
|
|
|
for (auto &version_pair : name_pair.second) {
|
|
|
|
for (auto &version_pair : name_pair.second) {
|
|
|
|
pending_vars.insert(&version_pair.second);
|
|
|
|
pending_vars[&version_pair.second] =
|
|
|
|
|
|
|
|
version_pair.second.generated_op_ == nullptr;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -300,56 +307,50 @@ std::vector<LoDTensor> ParallelExecutor::Run(
|
|
|
|
pending_ops.insert({op.get(), op->inputs_.size()});
|
|
|
|
pending_ops.insert({op.get(), op->inputs_.size()});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<OpHandle *> complete_op;
|
|
|
|
while (!pending_ops.empty()) {
|
|
|
|
|
|
|
|
VarHandle *ready_var = nullptr;
|
|
|
|
size_t num_op = pending_ops.size();
|
|
|
|
for (auto &pair : pending_vars) {
|
|
|
|
|
|
|
|
if (pair.second) {
|
|
|
|
while (complete_op.size() != num_op) {
|
|
|
|
ready_var = pair.first;
|
|
|
|
std::vector<VarHandle *> to_remove;
|
|
|
|
|
|
|
|
for (auto &var : pending_vars) {
|
|
|
|
|
|
|
|
if (var->generated_op_ == nullptr ||
|
|
|
|
|
|
|
|
complete_op.count(var->generated_op_) != 0) {
|
|
|
|
|
|
|
|
to_remove.push_back(var);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto *var : to_remove) {
|
|
|
|
|
|
|
|
pending_vars.erase(var);
|
|
|
|
if (ready_var == nullptr) {
|
|
|
|
|
|
|
|
member_->pool_.Wait(); // Wait thread pool;
|
|
|
|
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pending_vars.erase(ready_var);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<OpHandle *> to_run;
|
|
|
|
std::vector<OpHandle *> to_run;
|
|
|
|
for (auto *var : to_remove) {
|
|
|
|
|
|
|
|
for (auto *op : var->pending_ops_) {
|
|
|
|
for (auto *op : ready_var->pending_ops_) {
|
|
|
|
if (var->name_ == "mean_0.tmp_0@GRAD") {
|
|
|
|
auto &deps = pending_ops[op];
|
|
|
|
LOG(INFO) << op->DebugString();
|
|
|
|
--deps;
|
|
|
|
}
|
|
|
|
if (deps == 0) {
|
|
|
|
auto &num = pending_ops[op];
|
|
|
|
|
|
|
|
--num;
|
|
|
|
|
|
|
|
if (num == 0) {
|
|
|
|
|
|
|
|
to_run.emplace_back(op);
|
|
|
|
to_run.emplace_back(op);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto *op : to_run) {
|
|
|
|
for (auto *op : to_run) {
|
|
|
|
pending_ops.erase(op);
|
|
|
|
pending_ops.erase(op);
|
|
|
|
complete_op.insert(op);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (to_run.empty()) break;
|
|
|
|
std::vector<bool *> ready_buffer;
|
|
|
|
|
|
|
|
for (auto *var : op->outputs_) {
|
|
|
|
|
|
|
|
ready_buffer.emplace_back(&pending_vars[var]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(yy): Use thread pool to run OpHandle. Operators in ToRun can be
|
|
|
|
auto op_run = [ready_buffer, op] {
|
|
|
|
// parallelized. We can also use another scheduling method. Just a demo here.
|
|
|
|
// TODO(yy) Check Previous Op has same dev ctx.
|
|
|
|
|
|
|
|
LOG(INFO) << "Run " << op->DebugString();
|
|
|
|
|
|
|
|
for (auto *ready : ready_buffer) {
|
|
|
|
|
|
|
|
*ready = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
std::stringstream ss;
|
|
|
|
member_->pool_.Run(op_run);
|
|
|
|
ss << "\n";
|
|
|
|
|
|
|
|
for (auto *op : to_run) {
|
|
|
|
|
|
|
|
ss << op->DebugString() << "\n";
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
ss << std::endl;
|
|
|
|
|
|
|
|
LOG(INFO) << ss.str();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(complete_op.size(), num_op);
|
|
|
|
|
|
|
|
return std::vector<LoDTensor>();
|
|
|
|
return std::vector<LoDTensor>();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|