parent
c70b60dd70
commit
e3144393e3
@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
|
||||
: graph_(std::move(graph)) {}
|
||||
|
||||
// Defined out of line; graph_ is released by its own destructor.
SSAGraphExecutor::~SSAGraphExecutor() = default;
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "paddle/fluid/framework/details/ssa_graph.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_type.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class SSAGraphExecutor {
|
||||
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
|
||||
|
||||
public:
|
||||
// Steal graph inside
|
||||
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
|
||||
|
||||
virtual ~SSAGraphExecutor();
|
||||
|
||||
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<SSAGraph> graph_;
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,192 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
|
||||
|
||||
#include "paddle/fluid/framework/details/fetch_op_handle.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
// Builds an executor for `graph`.
//
// A thread pool is created only when `num_threads` is at least 2; otherwise
// pool_ stays null and RunOp() executes every op inline on the calling
// thread. `local_scopes` and `places` are copied into the executor, and
// fetch_ctxs_ is constructed from `places`.
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
    size_t num_threads, bool use_event,
    const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<SSAGraph> &&graph)
    : SSAGraphExecutor(std::move(graph)),
      pool_(num_threads < 2 ? nullptr : new ::ThreadPool(num_threads)),
      local_scopes_(local_scopes),
      places_(places),
      fetch_ctxs_(places),
      use_event_(use_event) {}
|
||||
|
||||
// Executes the SSA graph once and returns the fetched results.
//
// The graph is scheduled by dependency counting:
//   1. Every variable gets a readiness flag and every op with inputs gets a
//      counter of inputs that are not ready yet; ops without inputs are ready
//      immediately.
//   2. One FetchOpHandle (plus a dummy output variable) is appended per entry
//      of `fetch_tensors`, writing into slot i of the returned list.
//   3. The loop launches all ready ops through RunOp(), then polls for a
//      variable whose flag became true, removes it, and decrements the
//      counters of the ops waiting on it.
//
// If an op recorded an EnforceNotMet in exception_, it is rethrown from the
// polling loop.
FeedFetchList ThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Number of inputs each waiting op still needs.
  std::unordered_map<OpHandleBase *, size_t> pending_ops;
  // Readiness flag per variable; set to true by RunOp() after the op that
  // outputs the variable has run.
  std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
  // Ops that can be launched right now.
  std::unordered_set<OpHandleBase *> ready_ops;

  // A variable without a generating op is ready from the start.
  auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
    pending_vars[&var] = var.generated_op_ == nullptr;
  };

  auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
  };

  // Step 1. Transform SSAGraph to pending_ops & pending_vars
  for (auto &var_map : graph_->vars_) {
    for (auto &name_pair : var_map) {
      for (auto &version_pair : name_pair.second) {
        InsertPendingVar(version_pair.second);
      }
    }
  }
  for (auto &var : graph_->dep_vars_) {
    InsertPendingVar(*var);
  }

  for (auto &op : graph_->ops_) {
    if (op->inputs_.empty()) {  // Special case, Op has no input.
      ready_ops.insert(op.get());
    } else {
      InsertPendingOp(*op);
    }
  }

  // Step 2. Insert FetchOps
  std::vector<FetchOpHandle> fetch_ops;
  std::vector<DummyVarHandle> dummy_vars;
  // Addresses of the elements of both vectors are stored in pending_ops /
  // pending_vars and passed to AddOutput() below, so the storage must never
  // move: reserve everything up front, otherwise a reallocation triggered by
  // a later emplace_back would leave those pointers dangling.
  fetch_ops.reserve(fetch_tensors.size());
  dummy_vars.reserve(fetch_tensors.size());
  FeedFetchList fetch_data(fetch_tensors.size());

  // For every fetched name, the last-version handle found in each var map.
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->vars_) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto &vars = fetched_vars[var_name];
    fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
    details::FetchOpHandle *op = &fetch_ops.back();

    // FIXME: Use new device context
    for (auto &p : places_) {
      op->dev_ctx_[p] = fetch_ctxs_.Get(p);
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    // The dummy variable is the fetch op's only output, so the execution
    // loop below has something to wait on for this op.
    dummy_vars.emplace_back();
    auto *var = &dummy_vars.back();
    var->generated_op_ = nullptr;
    op->AddOutput(var);
    InsertPendingVar(*var);
    InsertPendingOp(*op);
  }

  auto run_all_ready_ops = [&] {
    for (auto *op : ready_ops) {
      RunOp(pending_vars, op);
    }
    ready_ops.clear();
  };

  // Step 3. Execution
  while (!pending_vars.empty()) {
    // 1. Run All Ready ops
    run_all_ready_ops();

    // 2. Find ready variable
    VarHandleBase *ready_var = nullptr;
    for (auto &pair : pending_vars) {
      if (pair.second.load(std::memory_order_acquire)) {
        ready_var = pair.first;
        break;
      }
    }

    // if there is no variable ready
    if (ready_var == nullptr) {
      // FIXME use conditional var instead of busy wait.
      // if there is an exception, throw it
      // NOTE(review): ops that were already handed to RunOp() hold pointers
      // into the local pending_vars; confirm they cannot still be running
      // when this throw unwinds the stack.
      if (exception_) {
        throw *exception_;
      }
      // keep waiting the ready variables
      continue;
    }

    // 3. Remove the dependency of ready_var.
    // Find the ready_ops after the ready_var.
    pending_vars.erase(ready_var);
    for (auto *op : ready_var->pending_ops_) {
      auto &deps = pending_ops[op];
      --deps;
      if (deps == 0) {
        ready_ops.insert(op);
      }
    }
    // Keep loop until all vars are ready.
  }

  // Wait FetchOps.
  for (auto &fetch_op : fetch_ops) {
    fetch_op.WaitAndMergeCPUTensors();
  }

  return fetch_data;
}
|
||||
|
||||
void ThreadedSSAGraphExecutor::RunOp(
|
||||
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
|
||||
details::OpHandleBase *op) {
|
||||
std::vector<std::atomic<bool> *> *ready_buffer =
|
||||
new std::vector<std::atomic<bool> *>();
|
||||
for (auto *var : op->outputs_) {
|
||||
ready_buffer->emplace_back(&pending_vars[var]);
|
||||
}
|
||||
|
||||
auto op_run = [ready_buffer, op, this] {
|
||||
try {
|
||||
VLOG(10) << op->DebugString();
|
||||
op->Run(use_event_);
|
||||
for (auto *ready : *ready_buffer) {
|
||||
ready->store(true, std::memory_order_release);
|
||||
}
|
||||
delete ready_buffer;
|
||||
} catch (platform::EnforceNotMet ex) {
|
||||
exception_.reset(new platform::EnforceNotMet(ex));
|
||||
} catch (...) {
|
||||
LOG(FATAL) << "Unknown exception catched";
|
||||
}
|
||||
};
|
||||
if (pool_) {
|
||||
pool_->enqueue(op_run);
|
||||
} else {
|
||||
op_run();
|
||||
}
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,55 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ThreadPool.h" // ThreadPool in third party
|
||||
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
class Scope;
|
||||
|
||||
namespace details {
|
||||
|
||||
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
|
||||
public:
|
||||
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
const std::vector<platform::Place> &places,
|
||||
std::unique_ptr<SSAGraph> &&graph);
|
||||
|
||||
// Run a SSAGraph by a thread pool
|
||||
// Use topological sort algorithm
|
||||
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
|
||||
|
||||
~ThreadedSSAGraphExecutor() {}
|
||||
|
||||
private:
|
||||
void RunOp(
|
||||
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
|
||||
details::OpHandleBase *op);
|
||||
|
||||
private:
|
||||
std::unique_ptr<::ThreadPool> pool_;
|
||||
std::vector<Scope *> local_scopes_;
|
||||
std::vector<platform::Place> places_;
|
||||
platform::DeviceContextPool fetch_ctxs_;
|
||||
const bool use_event_;
|
||||
std::unique_ptr<platform::EnforceNotMet> exception_;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue