|
|
|
@ -30,9 +30,6 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace details {
|
|
|
|
|
|
|
|
|
|
static std::atomic<unsigned int> exec_op_count_;
|
|
|
|
|
static std::atomic<int> error_state;
|
|
|
|
|
|
|
|
|
|
BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
|
|
|
|
|
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
|
|
|
|
|
const std::vector<Scope *> &local_exec_scopes,
|
|
|
|
@ -125,7 +122,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
|
|
|
|
|
for (auto cur_op : ready_fetch_ops) {
|
|
|
|
|
ready_ops->Push(cur_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Atomic variable, no need to lock
|
|
|
|
|
exec_op_count_ = 0;
|
|
|
|
|
|
|
|
|
|
platform::XPUPlace cur_place;
|
|
|
|
@ -134,9 +131,8 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
|
|
|
|
|
while (cur_count < op_deps_.size()) {
|
|
|
|
|
cur_count++;
|
|
|
|
|
auto cur_op = ready_ops->Pop();
|
|
|
|
|
// when execption, get cur_op == nullptr
|
|
|
|
|
if (cur_op == nullptr) {
|
|
|
|
|
// sleep a while to make sure worker thread quit
|
|
|
|
|
sleep(10);
|
|
|
|
|
exec_op_count_ = op_deps_.size();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
@ -151,14 +147,16 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
|
|
|
|
|
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
while (exec_op_count_ < op_deps_.size()) {
|
|
|
|
|
{
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
|
|
|
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
ClearFetchOp(graph_, &fetch_ops);
|
|
|
|
|
if (exception_.IsCaught()) {
|
|
|
|
|
ExecutionFinal(&fetch_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Wait FetchOps.
|
|
|
|
|
ClearFetchOp(graph_, &fetch_ops);
|
|
|
|
|
return fetches;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -222,7 +220,8 @@ void BindThreadedSSAGraphExecutor::InsertFetchOps(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RunMultiDeviceOpAsync function is used for Communicated OPs
|
|
|
|
|
// like all_reduce\broadcast among multicards.
|
|
|
|
|
void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
|
|
|
|
|
OpHandleBase *op,
|
|
|
|
|
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
|
|
|
|
@ -256,10 +255,12 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
|
|
|
|
|
ready_ops->Push(nullptr);
|
|
|
|
|
exception_.Catch(std::current_exception());
|
|
|
|
|
}
|
|
|
|
|
// Atomic variable, no need to lock
|
|
|
|
|
exec_op_count_++;
|
|
|
|
|
cv_.notify_all();
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RunOpAsyncMainStream function is used for computed OPs
|
|
|
|
|
void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
|
|
|
|
|
OpHandleBase *op,
|
|
|
|
|
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
|
|
|
|
@ -285,7 +286,9 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
|
|
|
|
|
ready_ops->Push(nullptr);
|
|
|
|
|
exception_.Catch(std::current_exception());
|
|
|
|
|
}
|
|
|
|
|
// Atomic variable, no need to lock
|
|
|
|
|
exec_op_count_++;
|
|
|
|
|
cv_.notify_all();
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|