|
|
@ -14,7 +14,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
#pragma once
|
|
|
|
#include <ThreadPool.h>
|
|
|
|
#include <ThreadPool.h>
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include "paddle/fluid/framework/blocking_queue.h"
|
|
|
|
#include "paddle/fluid/framework/blocking_queue.h"
|
|
|
|
#include "paddle/fluid/framework/details/exception_holder.h"
|
|
|
|
#include "paddle/fluid/framework/details/exception_holder.h"
|
|
|
@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
|
|
|
|
const ir::Graph &Graph() const override;
|
|
|
|
const ir::Graph &Graph() const override;
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
|
|
|
|
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
|
|
|
|
|
|
|
|
// be destroyed first.
|
|
|
|
ExecutionStrategy strategy_;
|
|
|
|
ExecutionStrategy strategy_;
|
|
|
|
std::vector<Scope *> local_scopes_;
|
|
|
|
std::vector<Scope *> local_scopes_;
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
|
|
|
|
std::unordered_map<OpHandleBase *, int> op_deps_;
|
|
|
|
std::unordered_map<OpHandleBase *, int> op_deps_;
|
|
|
|
std::vector<OpHandleBase *> bootstrap_ops_;
|
|
|
|
std::vector<OpHandleBase *> bootstrap_ops_;
|
|
|
|
|
|
|
|
|
|
|
|
::ThreadPool pool_;
|
|
|
|
|
|
|
|
::ThreadPool prepare_pool_;
|
|
|
|
|
|
|
|
platform::DeviceContextPool fetch_ctxs_;
|
|
|
|
platform::DeviceContextPool fetch_ctxs_;
|
|
|
|
std::atomic<int> remaining_;
|
|
|
|
std::atomic<int> remaining_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::future<
|
|
|
|
|
|
|
|
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
|
|
|
|
|
|
|
|
atomic_op_deps_;
|
|
|
|
|
|
|
|
ExceptionHolder exception_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
::ThreadPool pool_;
|
|
|
|
|
|
|
|
::ThreadPool prepare_pool_;
|
|
|
|
|
|
|
|
|
|
|
|
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
|
|
|
|
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
|
|
|
|
OpHandleBase *op,
|
|
|
|
OpHandleBase *op,
|
|
|
|
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
|
|
|
|
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
|
|
|
|
|
|
|
|
|
|
|
|
void PrepareAtomicOpDeps();
|
|
|
|
void PrepareAtomicOpDeps();
|
|
|
|
|
|
|
|
|
|
|
|
std::future<
|
|
|
|
|
|
|
|
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
|
|
|
|
|
|
|
|
atomic_op_deps_;
|
|
|
|
|
|
|
|
ExceptionHolder exception_;
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
} // namespace details
|
|
|
|
} // namespace details
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|