|
|
|
@ -201,17 +201,21 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
|
std::queue<AnfNodePtr> zero_input_nodes;
|
|
|
|
|
UpdateNodeEdgeList(&zero_input_nodes);
|
|
|
|
|
std::queue<AnfNodePtr> seed_nodes;
|
|
|
|
|
UpdateNodeEdgeList(&seed_nodes);
|
|
|
|
|
execution_order_.clear();
|
|
|
|
|
std::unordered_set<AnfNodePtr> visited_nodes;
|
|
|
|
|
std::queue<AnfNodePtr> zero_input_nodes;
|
|
|
|
|
AnfNodePtr last_communication_node = nullptr;
|
|
|
|
|
std::queue<AnfNodePtr> communication_descendants;
|
|
|
|
|
while (!zero_input_nodes.empty() || last_communication_node != nullptr) {
|
|
|
|
|
while (!seed_nodes.empty() || last_communication_node != nullptr) {
|
|
|
|
|
// seed nodes first, then visit last all reduce node descendant
|
|
|
|
|
if (last_communication_node != nullptr) {
|
|
|
|
|
if (seed_nodes.empty()) {
|
|
|
|
|
VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
|
|
|
|
|
last_communication_node = nullptr;
|
|
|
|
|
} else {
|
|
|
|
|
zero_input_nodes.push(seed_nodes.front());
|
|
|
|
|
seed_nodes.pop();
|
|
|
|
|
}
|
|
|
|
|
// all reduce node descendant first, then common queue
|
|
|
|
|
while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
|
|
|
|
@ -900,11 +904,14 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
|
|
|
|
seed_nodes->push(node);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto cnode = dyn_cast<CNode>(node);
|
|
|
|
|
if (cnode == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto &input : cnode->inputs()) {
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
// We push inputs from right to left, so that them can be evaluated from left to right.
|
|
|
|
|
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
|
|
|
|
|
auto &input = *iter;
|
|
|
|
|
PushNoVisitedNode(input, &que, &visited_nodes);
|
|
|
|
|
AddDependEdge(node, input, 1);
|
|
|
|
|
}
|
|
|
|
|