[Kunlun] Add condition_variable and notify() in BindThreadedSSAGraphExecutor (#30586)

revert-31068-fix_conv3d_windows
liuyuhui 4 years ago committed by GitHub
parent ca33821475
commit e5b0d9e1fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,9 +30,6 @@ namespace paddle {
namespace framework {
namespace details {
static std::atomic<unsigned int> exec_op_count_;
static std::atomic<int> error_state;
BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
@ -125,7 +122,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
for (auto cur_op : ready_fetch_ops) {
ready_ops->Push(cur_op);
}
// Atomic variable, no need to lock
exec_op_count_ = 0;
platform::XPUPlace cur_place;
@ -134,9 +131,8 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
while (cur_count < op_deps_.size()) {
cur_count++;
auto cur_op = ready_ops->Pop();
// On exception, a nullptr is pushed into ready_ops, so cur_op == nullptr
if (cur_op == nullptr) {
// Sleep for a while to make sure the worker threads have quit
sleep(10);
exec_op_count_ = op_deps_.size();
break;
}
@ -151,14 +147,16 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
}
}
while (exec_op_count_ < op_deps_.size()) {
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); });
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
@ -222,7 +220,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps(
}
}
}
// RunMultiDeviceOpAsync function is used for communication OPs
// such as all_reduce/broadcast across multiple cards.
void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
@ -256,10 +255,12 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
// Atomic variable, no need to lock
exec_op_count_++;
cv_.notify_all();
});
}
// RunOpAsyncMainStream function is used for computation OPs
void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
@ -285,7 +286,9 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
// Atomic variable, no need to lock
exec_op_count_++;
cv_.notify_all();
});
}

@ -14,7 +14,9 @@
#pragma once
#include <ThreadPool.h>
#include <condition_variable> // NOLINT
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
@ -76,6 +78,11 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool prepare_pool_;
::ThreadPool multi_device_op_pool_;
std::mutex mutex_;
std::condition_variable cv_;
std::atomic<unsigned int> exec_op_count_;
std::atomic<int> error_state;
void RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,

Loading…
Cancel
Save