|
|
@ -180,8 +180,8 @@ std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
|
|
|
return std::vector<AnfNodePtr>(1, graph_output);
|
|
|
|
return std::vector<AnfNodePtr>(1, graph_output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
|
|
|
void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
|
|
|
|
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first) {
|
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue);
|
|
|
|
MS_EXCEPTION_IF_NULL(visit_queue);
|
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes);
|
|
|
|
MS_EXCEPTION_IF_NULL(visited_nodes);
|
|
|
|
auto it = node_output_edges_.find(node);
|
|
|
|
auto it = node_output_edges_.find(node);
|
|
|
@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
|
|
|
|
while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
|
|
|
|
// seed nodes first, then delay comm nodes
|
|
|
|
// seed nodes first, then delay comm nodes
|
|
|
|
if (seed_nodes.empty()) {
|
|
|
|
if (seed_nodes.empty()) {
|
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
zero_input_nodes.push(seed_nodes.front());
|
|
|
|
zero_input_nodes.push(seed_nodes.front());
|
|
|
@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (optimize_comm) {
|
|
|
|
if (optimize_comm) {
|
|
|
|
while (!delay_comm_stack.empty()) {
|
|
|
|
while (!delay_comm_stack.empty()) {
|
|
|
|
VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false);
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
delay_comm_stack.pop();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
} else if (is_fused_comm) {
|
|
|
|
} else if (is_fused_comm) {
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
delay_comm_stack.push(node);
|
|
|
|
} else if (is_communication_descendant) {
|
|
|
|
} else if (is_communication_descendant) {
|
|
|
|
VisitNodeDescendants(node, &communication_descendants, &visited_nodes);
|
|
|
|
EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes);
|
|
|
|
EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|