From 4dab4521c36529b8a9e2d9c836a652bc2d8e8eaa Mon Sep 17 00:00:00 2001 From: liangzelang Date: Thu, 25 Mar 2021 19:49:23 +0800 Subject: [PATCH] unfold repeated label in labelswitches --- .../backend/session/ascend_auto_monad.cc | 106 +++++++++++++++++- 1 file changed, 102 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc index 901ae00968..7607f254ba 100644 --- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -30,6 +30,7 @@ #include "debug/anf_ir_dump.h" #include "pipeline/jit/base.h" #include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/ascend/kernel_select_ascend.h" namespace mindspore { namespace session { @@ -38,8 +39,8 @@ namespace { // Pair of graph and its actual arguments. using GraphArgPair = std::pair>; -// We start label id from 1, and use 0 to indicate label not set. -constexpr uint32_t kNoLabel = 0; +// We start label id from 0, and use 0xFFFFFFFF to indicate label not set. +constexpr uint32_t kNoLabel = 0xFFFFFFFF; // Primitive attribute for argument link assign. const char LINK[] = "link"; @@ -296,7 +297,7 @@ class AscendAutoMonadContext : public BaseContext { ParameterPool param_pool_; // Current label id. - uint32_t label_id_ = 1; + uint32_t label_id_ = 0; }; // @@ -1052,6 +1053,7 @@ class ExecuteOrderGenerator { GenerateExecuteOrder(); EraseParameter(); EraseLabel(); + UnfoldRepeatedLabels(); } private: @@ -1060,6 +1062,101 @@ class ExecuteOrderGenerator { generator.GenerateExecuteOrder(); } + uint32_t FindMaxLabelId(const std::vector &nodes) { + uint32_t max_label = 0; + for (auto &node : nodes) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { + auto label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + max_label = std::max(label_id, max_label); + } + } + return max_label; + } + + void HandleLabelSwitch(const AnfNodePtr &node, std::vector *labels, std::vector *switch_labels, + std::multimap *labels_multimap) { + bool is_new_labels = false; + auto label_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); + std::vector new_labels; + new_labels.reserve(label_list.size()); + for (auto label_id : label_list) { + auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; }); + // Use new label if find repeated label. + if (iter == labels->end()) { + new_labels.emplace_back(label_id); + continue; + } + new_labels.emplace_back(++max_label_); + labels_multimap->insert(std::pair(*iter, max_label_)); + is_new_labels = true; + } + labels->insert(labels->end(), new_labels.begin(), new_labels.end()); + switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end()); + if (is_new_labels) { + AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node); + } + } + + void HandleLabelGoto(const AnfNodePtr &node, std::vector *labels, std::vector *switch_labels, + std::multimap *labels_multimap) { + auto label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id); + if (iter == switch_labels->end()) { + labels->emplace_back(label_id); + return; + } + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node); + labels_multimap->insert(std::pair(*iter, max_label_)); + labels->emplace_back(max_label_); + } + + // Unfold Repeated Labels, avoid same label in labelswitches. + void UnfoldRepeatedLabels() { + auto nodes = graph_->execution_order(); + std::vector labels; + std::vector switch_labels; + std::multimap labels_multimap; + max_label_ = FindMaxLabelId(nodes); + for (auto &node : nodes) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap); + continue; + } + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap); + continue; + } + } + InsertLabelSet(&nodes, labels_multimap); + graph_->set_label_num(max_label_ + 1); + graph_->set_execution_order(nodes); + } + + void InsertLabelSet(std::vector *nodes, const std::multimap &labels_multimap) { + for (auto labels : labels_multimap) { + auto old_label = labels.first; + auto new_label = labels.second; + auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) { + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) { + return false; + } + auto label_id = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + return label_id == old_label; + }); + if (iter == nodes->end()) { + MS_LOG(EXCEPTION) << "Not found labelset:" << old_label; + } + auto label_set = NewValueNode(std::make_shared(prim::kPrimLabelSet->name())); + auto cnode = graph_->NewCNode({label_set}); + AnfAlgo::CopyNodeAttrs(*iter, cnode); + AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode); + auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad); + cnode->set_abstract(monad->abstract()); + device::ascend::SelectKernelInfo(cnode); + nodes->insert(iter, cnode); + } + } + void AppendGraphOrder(std::vector *execution_order, const KernelGraphPtr &graph) { auto &order = graph->execution_order(); execution_order->insert(execution_order->end(), order.begin(), order.end()); @@ -1343,6 +1440,7 @@ class ExecuteOrderGenerator { Context &context_; const KernelGraphPtr graph_; + uint32_t max_label_ = 0; }; } // namespace @@ -1353,7 +1451,7 @@ void AscendAutoMonad::Run() { AscendAutoMonadContext context(kg); CallInfoFinder::Run(&context); AscendAutoMonadConverter::Run(&context); - kernel_graph_->set_label_num(context.CurrentLabel()); + kernel_graph_->set_label_num(context.CurrentLabel() + 1); MS_LOG(DEBUG) << "Ascend auto-monad finish."; DumpGraphForDebug(kernel_graph_); }