|
|
@ -35,18 +35,18 @@ using details::VarHandleBase;
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit ParallelExecutorPrivate(size_t num_threads)
|
|
|
|
explicit ParallelExecutorPrivate(size_t num_threads,
|
|
|
|
: pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
const std::vector<platform::Place> &places)
|
|
|
|
|
|
|
|
: places_(places),
|
|
|
|
|
|
|
|
fetch_dev_ctxs_(places),
|
|
|
|
|
|
|
|
pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
|
|
|
|
platform::DeviceContextPool fetch_dev_ctxs_;
|
|
|
|
std::vector<Scope *> local_scopes_;
|
|
|
|
std::vector<Scope *> local_scopes_;
|
|
|
|
Scope *global_scope_;
|
|
|
|
Scope *global_scope_;
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
|
|
|
|
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
|
|
|
|
std::unordered_map<platform::Place, platform::DeviceContext *,
|
|
|
|
|
|
|
|
platform::PlaceHash>
|
|
|
|
|
|
|
|
fetch_dev_ctxs_;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
platform::Place main_place_;
|
|
|
|
platform::Place main_place_;
|
|
|
|
|
|
|
|
|
|
|
@ -219,20 +219,9 @@ ParallelExecutor::ParallelExecutor(
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const std::unordered_set<std::string> ¶ms,
|
|
|
|
const ProgramDesc &startup_program, const ProgramDesc &main_program,
|
|
|
|
const ProgramDesc &startup_program, const ProgramDesc &main_program,
|
|
|
|
const std::string &loss_var_name, Scope *scope)
|
|
|
|
const std::string &loss_var_name, Scope *scope)
|
|
|
|
: member_(new ParallelExecutorPrivate(num_threads)) {
|
|
|
|
: member_(new ParallelExecutorPrivate(num_threads, places)) {
|
|
|
|
member_->places_ = places;
|
|
|
|
|
|
|
|
member_->global_scope_ = scope;
|
|
|
|
member_->global_scope_ = scope;
|
|
|
|
|
|
|
|
|
|
|
|
if (platform::is_cpu_place(places[0])) {
|
|
|
|
|
|
|
|
member_->fetch_dev_ctxs_[places[0]] = const_cast<platform::DeviceContext *>(
|
|
|
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(places[0]));
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
for (auto &p : member_->places_) {
|
|
|
|
|
|
|
|
member_->fetch_dev_ctxs_[p] =
|
|
|
|
|
|
|
|
new platform::CUDADeviceContext(boost::get<platform::CUDAPlace>(p));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Step 1. RunStartupProgram and Bcast the params to devs.
|
|
|
|
// Step 1. RunStartupProgram and Bcast the params to devs.
|
|
|
|
Executor exe(places[0]);
|
|
|
|
Executor exe(places[0]);
|
|
|
|
exe.Run(startup_program, scope, 0);
|
|
|
|
exe.Run(startup_program, scope, 0);
|
|
|
@ -509,7 +498,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
|
|
|
|
|
|
|
// FIXME: Use new device context
|
|
|
|
// FIXME: Use new device context
|
|
|
|
for (auto &p : member_->places_) {
|
|
|
|
for (auto &p : member_->places_) {
|
|
|
|
op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p];
|
|
|
|
op->dev_ctx_[p] = member_->fetch_dev_ctxs_.Get(p);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto *var : vars) {
|
|
|
|
for (auto *var : vars) {
|
|
|
|