commit ff052c0e6f

@@ -0,0 +1,175 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
| #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include "paddle/fluid/framework/details/fetch_op_handle.h" | ||||
| #include "paddle/fluid/framework/details/multi_devices_helper.h" | ||||
| 
 | ||||
namespace paddle {
namespace framework {
namespace details {

FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places,
    std::unique_ptr<ir::Graph> &&graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      places_(places),
      graph_(std::move(graph)),
      pool_(strategy.num_threads_ +
            1),  // one extra thread to generate op_deps asynchronously
      fetch_ctxs_(places) {
  auto &ops = graph_->Get<details::GraphOps>("ops");

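  // Count each op's not-ready inputs; ops that start at zero have no
  // dependencies and form the bootstrap set that seeds every Run().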
  for (auto &op : ops) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op.get(), dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op.get());
    }
  }

  PrepareAtomicOpDeps();
}

FeedFetchList FastThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
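  // Consume the dependency-count map prepared in advance, then immediately
  // start building the next one so consecutive Run() calls overlap the copy.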
  std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();

  paddle::framework::FeedFetchList fetches;
  fetches.resize(fetch_tensors.size());
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<std::unique_ptr<ir::Node>> fetch_nodes;
  std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;

  for (auto &fetch_var_name : fetch_tensors) {
    for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        // rbegin(): take the latest SSA version of the variable on each device.
        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
      }
    }
  }

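  // One FetchOpHandle per requested tensor: it takes that tensor's var
  // handles on all devices as inputs and writes the merged result into
  // fetches[i].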
  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors[i];
    auto fetched_var_it = fetched_vars.find(var_name);
    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
                   "Cannot find fetched variable (perhaps the main_program "
                   "was not set for this ParallelExecutor).");

    auto &vars = fetched_var_it->second;

    fetch_nodes.emplace_back(new ir::Node("fetch", ir::Node::Type::kOperation));
    auto *op = new FetchOpHandle(fetch_nodes.back().get(), &fetches, i,
                                 &local_scopes_);
    fetch_ops.emplace_back(op);

    for (auto &p : places_) {
      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    (*op_deps)[op] = static_cast<int>(op->NotReadyInputSize());
  }

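  // Kick off every dependency-free op; workers report how many ops they
  // completed through complete_q, or a size_t(-1) sentinel if one threw.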
  size_t num_complete = 0;
  remaining_ = 0;
  BlockingQueue<size_t> complete_q;
  for (auto op : bootstrap_ops_) {
    RunOpAsync(op_deps.get(), op, &complete_q);
  }

  while (num_complete != op_deps->size()) {
    size_t num_comp = complete_q.Pop();
    if (num_comp == static_cast<size_t>(-1)) {  // an op threw an exception
      // Drain the queue until no op is still in flight, then rethrow.
      int remaining = 0;
      while (true) {
        remaining = remaining_;
        if (remaining == 0) {
          break;
        }
        for (int i = 0; i < remaining; ++i) {
          complete_q.Pop();
        }
      }
      exception_.ReThrow();
    }
    num_complete += num_comp;
  }
  // Destroying the fetch ops waits for them to finish.
  if (!fetch_ops.empty()) {
    fetch_ops.clear();
  }
  return fetches;
}

void FastThreadedSSAGraphExecutor::RunOpAsync(
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
    OpHandleBase *op, BlockingQueue<size_t> *complete_q) {
  ++remaining_;
  this->pool_.enqueue([=] {
    OpHandleBase *op_to_run = op;
    size_t complete = 0;
    while (op_to_run != nullptr) {
      try {
        op_to_run->Run(strategy_.use_cuda_);
        ++complete;
      } catch (...) {
        exception_.Catch(std::current_exception());
        --remaining_;
        complete_q->Push(static_cast<size_t>(-1));  // signal failure
        return;
      }
      auto &outputs = op_to_run->Outputs();
      op_to_run = nullptr;
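      // Release each successor's dependency; the first one that becomes
      // ready is chained inline on this worker, the rest are re-enqueued.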
|       for (auto &output : outputs) { | ||||
|         for (auto &pending_op : output->PendingOps()) { | ||||
|           std::atomic<int> &deps = op_deps->at(pending_op); | ||||
|           if (deps.fetch_sub(1) == 1) {  // pending_op ready
 | ||||
|             if (op_to_run == nullptr) { | ||||
|               op_to_run = pending_op; | ||||
|             } else { | ||||
|               this->RunOpAsync(op_deps, pending_op, complete_q); | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     --remaining_; | ||||
|     complete_q->Push(complete); | ||||
|   }); | ||||
| } | ||||
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
  // Build a fresh std::atomic<int> copy of op_deps_ on a pool thread; each
  // Run() consumes one copy, since the counters are counted down
  // destructively as ops execute.
  atomic_op_deps_ = pool_.enqueue([&] {
    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps =
        new std::unordered_map<OpHandleBase *, std::atomic<int>>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first] = pair.second;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, std::atomic<int>>>(op_deps);
  });
}

const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,64 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <atomic>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"

namespace paddle {
namespace framework {
class Scope;
namespace details {

class OpHandleBase;

class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
 public:
  FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
                               const std::vector<Scope *> &local_scopes,
                               const std::vector<platform::Place> &places,
                               std::unique_ptr<ir::Graph> &&graph);
  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
  const ir::Graph &Graph() const override;

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::vector<platform::Place> places_;
  std::unique_ptr<ir::Graph> graph_;

  std::unordered_map<OpHandleBase *, int> op_deps_;  // not-ready input counts
  std::vector<OpHandleBase *> bootstrap_ops_;        // ops with zero deps

  ::ThreadPool pool_;
  platform::DeviceContextPool fetch_ctxs_;
  std::atomic<int> remaining_;  // ops still in flight in the current Run()

  void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
                  OpHandleBase *op, BlockingQueue<size_t> *complete_q);

  void PrepareAtomicOpDeps();

  // Pre-built atomic copy of op_deps_, consumed by the next Run().
  std::future<
      std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
      atomic_op_deps_;
  ExceptionHolder exception_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
@@ -1,4 +1,4 @@
-nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
+nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
 nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
 nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
 add_subdirectory(convert)
Some files were not shown because too many files have changed in this diff.
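The scheduling trick in RunOpAsync above is an atomic dependency countdown: each op stores how many inputs are not yet ready, finishing an op decrements its successors' counters, and whichever worker drops a counter to zero gets to run that op, chaining the first ready successor on the same thread instead of going back through a queue. Below is a minimal, self-contained sketch of that pattern; it is not Paddle code: detached std::threads stand in for the ::ThreadPool, a spin-wait stands in for the BlockingQueue-based completion count, and all names (Node, RunAsync, completed) are illustrative.

// dep_countdown_sketch.cc — compile with: g++ -std=c++11 -pthread
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// One graph node: an atomic count of unfinished inputs plus the list of
// downstream nodes that consume this node's output.
struct Node {
  explicit Node(int i) : id(i) {}
  int id;
  std::atomic<int> deps{0};
  std::vector<Node *> pending;
};

std::atomic<int> completed{0};

void RunAsync(Node *node) {
  Node *to_run = node;
  while (to_run != nullptr) {
    std::printf("run node %d\n", to_run->id);
    completed.fetch_add(1);
    Node *next = nullptr;
    for (Node *succ : to_run->pending) {
      // fetch_sub returns the previous value, so a result of 1 means this
      // thread just released the successor's last unfinished input.
      if (succ->deps.fetch_sub(1) == 1) {
        if (next == nullptr) {
          next = succ;  // chain the first ready successor on this thread
        } else {
          std::thread([succ] { RunAsync(succ); }).detach();
        }
      }
    }
    to_run = next;
  }
}

int main() {
  // Diamond-shaped graph: 0 feeds 1 and 2, which both feed 3.
  Node n0(0), n1(1), n2(2), n3(3);
  n0.pending = {&n1, &n2};
  n1.pending = {&n3};
  n2.pending = {&n3};
  n1.deps = 1;
  n2.deps = 1;
  n3.deps = 2;
  RunAsync(&n0);  // node 0 is the "bootstrap" op: zero dependencies
  while (completed.load() < 4) std::this_thread::yield();  // crude join
  return 0;
}

Chaining the first ready successor inline (op_to_run in the real executor, next in the sketch) keeps a dependency chain on a single thread and avoids one queue round-trip per op; only additional successors that become ready at the same time pay for a new task.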