|
|
|
@ -349,11 +349,10 @@ void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotN
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
std::set<KernelGraphPtr> memo;
|
|
|
|
|
(void)RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo));
|
|
|
|
|
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
|
|
|
|
|
NotNull<KernelGraphPtr> graph,
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
|
|
|
|
NotNull<std::set<KernelGraphPtr> *> memo) {
|
|
|
|
|
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
|
|
|
|
|
auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
|
|
|
|
@ -366,52 +365,38 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
memo->insert(graph.get());
|
|
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
|
|
|
|
|
graph->SetExecOrderByDefault();
|
|
|
|
|
|
|
|
|
|
const std::vector<CNodePtr> &cnodes = graph->execution_order();
|
|
|
|
|
std::map<uint32_t, CNodePtr> label_map;
|
|
|
|
|
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
|
|
|
|
|
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes);
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> execution_order;
|
|
|
|
|
uint32_t child_order_index = 0;
|
|
|
|
|
|
|
|
|
|
for (auto &node : cnodes) {
|
|
|
|
|
execution_order.push_back(node);
|
|
|
|
|
if (node == graph->get_end_goto()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto label_iter =
|
|
|
|
|
std::find_if(label_map.begin(), label_map.end(),
|
|
|
|
|
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; });
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
|
|
|
|
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
|
|
|
|
|
if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail";
|
|
|
|
|
}
|
|
|
|
|
auto child_graph = child_graph_order[label_iter->first];
|
|
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++];
|
|
|
|
|
if (child_graph == graph->parent_graph()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::map<uint32_t, CNodePtr> child_label_map;
|
|
|
|
|
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
|
|
|
|
|
auto child_execution_order =
|
|
|
|
|
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
|
|
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
|
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
|
|
|
|
std::vector<uint32_t> label_list = label_switch_map.find(node)->second;
|
|
|
|
|
std::reverse(label_list.begin(), label_list.end());
|
|
|
|
|
for (size_t i = 0; i < label_list.size(); ++i) {
|
|
|
|
|
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) {
|
|
|
|
|
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
|
|
|
|
|
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
|
|
|
|
|
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check label index fail";
|
|
|
|
|
}
|
|
|
|
|
auto child_graph = child_graph_order[label_iter->first + i];
|
|
|
|
|
auto child_graph = graph->child_graph_order()[child_order_index++];
|
|
|
|
|
if (child_graph == graph->parent_graph()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::map<uint32_t, CNodePtr> child_label_map;
|
|
|
|
|
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
|
|
|
|
|
auto child_execution_order =
|
|
|
|
|
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
|
|
|
|
|
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
|
|
|
|
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -421,6 +406,15 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
|
|
|
|
|
return execution_order;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
|
|
|
|
|
}
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
|
|
|
|
|
NotNull<KernelGraphPtr> graph) {
|
|
|
|
|
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
|
|
|
|
@ -458,31 +452,6 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode(
|
|
|
|
|
const std::vector<CNodePtr> &nodes) {
|
|
|
|
|
std::map<uint32_t, CNodePtr> label_map;
|
|
|
|
|
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
|
|
|
|
|
// record child graph
|
|
|
|
|
uint32_t index = 0;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
|
|
|
|
label_map[index++] = node;
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
|
|
|
|
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
|
|
|
|
|
}
|
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
|
|
|
|
|
label_switch_map.insert({node, label_list});
|
|
|
|
|
for (size_t i = 0; i < label_list.size(); ++i) {
|
|
|
|
|
label_map[index++] = node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return {label_map, label_switch_map};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
|
|
|
|
MS_LOG(INFO) << "graph id:" << kg->graph_id();
|
|
|
|
|
kg->SetExecOrderByDefault();
|
|
|
|
|