|
|
@ -12,8 +12,8 @@
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// limitations under the License.
|
|
|
|
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
|
|
|
|
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
|
|
|
|
|
|
|
|
#include <deque>
|
|
|
|
#include <memory>
|
|
|
|
#include <memory>
|
|
|
|
#include <queue>
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <unordered_set>
|
|
|
|
#include <unordered_set>
|
|
|
@ -191,13 +191,13 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
|
|
|
|
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
|
|
|
|
const std::shared_ptr<BlockingQueue<size_t>> &complete_q) {
|
|
|
|
++remaining_;
|
|
|
|
++remaining_;
|
|
|
|
this->pool_.enqueue([=] {
|
|
|
|
this->pool_.enqueue([=] {
|
|
|
|
std::queue<OpHandleBase *> op_queue;
|
|
|
|
std::deque<OpHandleBase *> op_queue;
|
|
|
|
op_queue.push(op);
|
|
|
|
op_queue.push_front(op);
|
|
|
|
|
|
|
|
|
|
|
|
size_t complete = 0;
|
|
|
|
size_t complete = 0;
|
|
|
|
while (!op_queue.empty()) {
|
|
|
|
while (!op_queue.empty()) {
|
|
|
|
OpHandleBase *op_to_run = op_queue.front();
|
|
|
|
OpHandleBase *op_to_run = op_queue.back();
|
|
|
|
op_queue.pop();
|
|
|
|
op_queue.pop_back();
|
|
|
|
|
|
|
|
|
|
|
|
if (!RunOp(op_to_run, complete_q, &complete)) {
|
|
|
|
if (!RunOp(op_to_run, complete_q, &complete)) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
@ -213,7 +213,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
|
|
|
|
// NOTE(zjl): op with highest priority should run
|
|
|
|
// NOTE(zjl): op with highest priority should run
|
|
|
|
// first without switching to another thread.
|
|
|
|
// first without switching to another thread.
|
|
|
|
if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
|
|
|
|
if (pending_op->GetPriority() == OpHandleBase::Priority::kHighest) {
|
|
|
|
op_queue.push(pending_op);
|
|
|
|
op_queue.push_back(pending_op);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
if (op_to_run == nullptr) {
|
|
|
|
if (op_to_run == nullptr) {
|
|
|
|
op_to_run = pending_op;
|
|
|
|
op_to_run = pending_op;
|
|
|
@ -224,7 +224,9 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (op_to_run != nullptr) op_queue.push(op_to_run);
|
|
|
|
if (op_to_run != nullptr) {
|
|
|
|
|
|
|
|
op_queue.push_front(op_to_run);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
--remaining_;
|
|
|
|
--remaining_;
|
|
|
|
complete_q->Push(complete);
|
|
|
|
complete_q->Push(complete);
|
|
|
|