commit ff052c0e6f

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -0,0 +1,175 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"

namespace paddle {
namespace framework {
namespace details {

FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<ir::Graph> &&graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      places_(places),
      graph_(std::move(graph)),
      pool_(strategy.num_threads_ +
            1),  // add one more thread to generate op_deps
      fetch_ctxs_(places) {
  auto &ops = graph_->Get<details::GraphOps>("ops");

  // Record how many inputs each op is still waiting for; ops with no pending
  // inputs bootstrap the execution.
  for (auto &op : ops) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op.get(), dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op.get());
    }
  }

  PrepareAtomicOpDeps();
}

FeedFetchList FastThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // Take the atomic dependency map prepared in the background and immediately
  // start preparing the one for the next Run().
  std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();

  paddle::framework::FeedFetchList fetches;
  fetches.resize(fetch_tensors.size());
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto fetched_var_it = fetched_vars.find(var_name);
    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
                   "Cannot find fetched variable (perhaps the main_program "
                   "is not set to ParallelExecutor).");

    auto &vars = fetched_var_it->second;

    fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
    auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i,
                                 &local_scopes_);
    fetch_ops.emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    (*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
  }

  size_t num_complete = 0;
  remaining_ = 0;
  BlockingQueue<size_t> complete_q;
  for (auto op : bootstrap_ops_) {
    RunOpAsync(op_deps.get(), op, &complete_q);
  }

  while (num_complete != op_deps->size()) {
    size_t num_comp = complete_q.Pop();
    // A worker pushes -1UL when an op throws. Wait for the ops that are still
    // in flight to report, then rethrow the captured exception.
    if (num_comp == -1UL) {
      int remaining = 0;
      while (true) {
        remaining = remaining_;
        if (remaining == 0) {
          break;
        }
        for (int i = 0; i < remaining; ++i) {
          complete_q.Pop();
        }
      }
      exception_.ReThrow();
    }
    num_complete += num_comp;
  }
  // Wait for FetchOps.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
  }
  return fetches;
}

void FastThreadedSSAGraphExecutor::RunOpAsync(
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    OpHandleBase *op, BlockingQueue<size_t> *complete_q) {
  ++remaining_;
  this->pool_.enqueue([=] {
    OpHandleBase *op_to_run = op;
    size_t complete = 0;
    while (op_to_run != nullptr) {
      try {
        op_to_run->Run(strategy_.use_cuda_);
        ++complete;
      } catch (...) {
        exception_.Catch(std::current_exception());
        --remaining_;
        complete_q->Push(-1UL);
        return;
      }
      // Decrement the dependency counters of the ops that consume this op's
      // outputs; run one newly ready op on this thread and enqueue the rest.
      auto &outputs = op_to_run->Outputs();
      op_to_run = nullptr;
      for (auto &output : outputs) {
        for (auto &pending_op : output->PendingOps()) {
          std::atomic<int> &deps = op_deps->at(pending_op);
          if (deps.fetch_sub(1) == 1) {  // pending_op ready
            if (op_to_run == nullptr) {
              op_to_run = pending_op;
            } else {
              this->RunOpAsync(op_deps, pending_op, complete_q);
            }
          }
        }
      }
    }
    --remaining_;
    complete_q->Push(complete);
  });
}

void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
  // Copy the static dependency counts into a fresh map of atomics on a pool
  // thread, so the next Run() can consume it without rebuilding it inline.
  atomic_op_deps_ = pool_.enqueue([&] {
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps =
        new std::unordered_map<OpHandleBase *, std::atomic<int>>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first] = pair.second;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
  });
}

const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }

}  // namespace details
}  // namespace framework
}  // namespace paddle
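
The heart of RunOpAsync above is a dependency-countdown scheduler: every op carries an atomic count of unfinished inputs, and the worker that finishes an op decrements its consumers' counters, keeps one newly ready consumer on the current thread, and hands any others off. What follows is a minimal, self-contained sketch of that pattern only, not PaddlePaddle code; Node, g_done, and the yield loop are stand-ins for OpHandleBase, the BlockingQueue, and the thread pool.

#include <atomic>
#include <iostream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> pending;  // nodes that consume this node's output
};

std::atomic<int> g_done{0};  // total nodes finished (crude BlockingQueue stand-in)

void RunAsync(Node *node, std::unordered_map<Node *, std::atomic<int>> *deps) {
  std::thread([=] {
    Node *to_run = node;
    while (to_run != nullptr) {
      std::cout << "run " + to_run->name + "\n";  // the "op" body
      Node *next = nullptr;
      for (Node *p : to_run->pending) {
        if (deps->at(p).fetch_sub(1) == 1) {  // p just became ready
          if (next == nullptr) {
            next = p;           // run one successor on this thread
          } else {
            RunAsync(p, deps);  // spawn the rest
          }
        }
      }
      to_run = next;
      ++g_done;  // report completion last, so main never outlives our accesses
    }
  }).detach();
}

int main() {
  Node a{"a"}, b{"b"}, c{"c"}, d{"d"};  // diamond DAG: a -> {b, c} -> d
  a.pending = {&b, &c};
  b.pending = {&d};
  c.pending = {&d};
  std::unordered_map<Node *, std::atomic<int>> deps;
  deps[&a] = 0;
  deps[&b] = 1;
  deps[&c] = 1;
  deps[&d] = 2;
  RunAsync(&a, &deps);  // bootstrap node: no unfinished inputs
  while (g_done.load() < 4) std::this_thread::yield();
  return 0;
}

Running one ready successor inline keeps linear chains of ops on a single thread instead of bouncing every op through the queue; only genuine fan-out spawns extra work.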

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h
@@ -0,0 +1,64 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <atomic>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"

namespace paddle {
namespace framework {
class Scope;
namespace details {

class OpHandleBase;

class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
 public:
  FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
                               const std::vector<Scope *> &local_scopes,
                               const std::vector<platform::Place> &places,
                               std::unique_ptr<ir::Graph> &&graph);
  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
  const ir::Graph &Graph() const override;

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::vector<platform::Place> places_;
  std::unique_ptr<ir::Graph> graph_;

  // Static count of not-ready inputs per op, and the ops that start with none.
  std::unordered_map<OpHandleBase *, int> op_deps_;
  std::vector<OpHandleBase *> bootstrap_ops_;

  ::ThreadPool pool_;
  platform::DeviceContextPool fetch_ctxs_;
  std::atomic<int> remaining_;

  void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
                  OpHandleBase *op, BlockingQueue<size_t> *complete_q);

  void PrepareAtomicOpDeps();

  // Atomic copy of op_deps_ prepared asynchronously for the next Run().
  std::future<
      std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
      atomic_op_deps_;
  ExceptionHolder exception_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
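
The atomic_op_deps_ future declared above is a simple double-buffering scheme: each Run() consumes a dependency map that was copied on a background thread during the previous run, then immediately schedules the copy for the next run, so rebuilding the counters overlaps with executing the graph. Below is a minimal sketch of just that pattern, with assumed names and plain int counters standing in for the std::atomic<int> values; it is not PaddlePaddle code.

#include <future>
#include <memory>
#include <unordered_map>

class Executor {
 public:
  explicit Executor(std::unordered_map<int, int> deps)
      : op_deps_(std::move(deps)) {
    PrepareNext();
  }

  void Run() {
    auto deps = next_deps_.get();  // map prepared during the previous Run()
    PrepareNext();                 // start building the map for the next Run()
    // ... execute the graph, decrementing counters in *deps ...
  }

 private:
  void PrepareNext() {
    next_deps_ = std::async(std::launch::async, [this] {
      // Copy the immutable per-op input counts into a fresh map.
      return std::make_unique<std::unordered_map<int, int>>(op_deps_);
    });
  }

  std::unordered_map<int, int> op_deps_;  // static per-op input counts
  std::future<std::unique_ptr<std::unordered_map<int, int>>> next_deps_;
};

// Usage: Executor e({{0, 0}, {1, 2}}); e.Run(); e.Run();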

@@ -1,4 +1,4 @@
-nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
+nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
 add_subdirectory(convert)
Some files were not shown because too many files have changed in this diff.