unfold repeated label in labelswitches

pull/14149/head
liangzelang 4 years ago
parent 8ef0e8a8b0
commit 4dab4521c3

@ -30,6 +30,7 @@
#include "debug/anf_ir_dump.h"
#include "pipeline/jit/base.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/kernel_select_ascend.h"
namespace mindspore {
namespace session {
@ -38,8 +39,8 @@ namespace {
// Pair of graph and its actual arguments.
using GraphArgPair = std::pair<KernelGraphPtr, std::vector<AnfNodePtr>>;
// We start label id from 1, and use 0 to indicate label not set.
constexpr uint32_t kNoLabel = 0;
// We start label id from 0, and use 0xFFFFFFFF to indicate label not set.
constexpr uint32_t kNoLabel = 0xFFFFFFFF;
// Primitive attribute for argument link assign.
const char LINK[] = "link";
@ -296,7 +297,7 @@ class AscendAutoMonadContext : public BaseContext {
ParameterPool param_pool_;
// Current label id.
uint32_t label_id_ = 1;
uint32_t label_id_ = 0;
};
//
@ -1052,6 +1053,7 @@ class ExecuteOrderGenerator {
GenerateExecuteOrder();
EraseParameter();
EraseLabel();
UnfoldRepeatedLabels();
}
private:
@ -1060,6 +1062,101 @@ class ExecuteOrderGenerator {
generator.GenerateExecuteOrder();
}
uint32_t FindMaxLabelId(const std::vector<CNodePtr> &nodes) {
uint32_t max_label = 0;
for (auto &node : nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
max_label = std::max(label_id, max_label);
}
}
return max_label;
}
void HandleLabelSwitch(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
std::multimap<uint32_t, uint32_t> *labels_multimap) {
bool is_new_labels = false;
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(node, kAttrLabelSwitchList);
std::vector<uint32_t> new_labels;
new_labels.reserve(label_list.size());
for (auto label_id : label_list) {
auto iter = std::find_if(labels->begin(), labels->end(), [label_id](auto id) { return id == label_id; });
// Use new label if find repeated label.
if (iter == labels->end()) {
new_labels.emplace_back(label_id);
continue;
}
new_labels.emplace_back(++max_label_);
labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_));
is_new_labels = true;
}
labels->insert(labels->end(), new_labels.begin(), new_labels.end());
switch_labels->insert(switch_labels->end(), new_labels.begin(), new_labels.end());
if (is_new_labels) {
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue(new_labels), node);
}
}
void HandleLabelGoto(const AnfNodePtr &node, std::vector<uint32_t> *labels, std::vector<uint32_t> *switch_labels,
std::multimap<uint32_t, uint32_t> *labels_multimap) {
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
auto iter = std::find(switch_labels->begin(), switch_labels->end(), label_id);
if (iter == switch_labels->end()) {
labels->emplace_back(label_id);
return;
}
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(++max_label_), node);
labels_multimap->insert(std::pair<uint32_t, uint32_t>(*iter, max_label_));
labels->emplace_back(max_label_);
}
// Unfold Repeated Labels, avoid same label in labelswitches.
void UnfoldRepeatedLabels() {
auto nodes = graph_->execution_order();
std::vector<uint32_t> labels;
std::vector<uint32_t> switch_labels;
std::multimap<uint32_t, uint32_t> labels_multimap;
max_label_ = FindMaxLabelId(nodes);
for (auto &node : nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
HandleLabelSwitch(node, &labels, &switch_labels, &labels_multimap);
continue;
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
HandleLabelGoto(node, &labels, &switch_labels, &labels_multimap);
continue;
}
}
InsertLabelSet(&nodes, labels_multimap);
graph_->set_label_num(max_label_ + 1);
graph_->set_execution_order(nodes);
}
void InsertLabelSet(std::vector<CNodePtr> *nodes, const std::multimap<uint32_t, uint32_t> &labels_multimap) {
for (auto labels : labels_multimap) {
auto old_label = labels.first;
auto new_label = labels.second;
auto iter = std::find_if(nodes->begin(), nodes->end(), [old_label](auto node) {
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSet)) {
return false;
}
auto label_id = AnfAlgo::GetNodeAttr<uint32_t>(node, kAttrLabelIndex);
return label_id == old_label;
});
if (iter == nodes->end()) {
MS_LOG(EXCEPTION) << "Not found labelset:" << old_label;
}
auto label_set = NewValueNode(std::make_shared<Primitive>(prim::kPrimLabelSet->name()));
auto cnode = graph_->NewCNode({label_set});
AnfAlgo::CopyNodeAttrs(*iter, cnode);
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(new_label), cnode);
auto monad = graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
cnode->set_abstract(monad->abstract());
device::ascend::SelectKernelInfo(cnode);
nodes->insert(iter, cnode);
}
}
void AppendGraphOrder(std::vector<CNodePtr> *execution_order, const KernelGraphPtr &graph) {
auto &order = graph->execution_order();
execution_order->insert(execution_order->end(), order.begin(), order.end());
@ -1343,6 +1440,7 @@ class ExecuteOrderGenerator {
Context &context_;
const KernelGraphPtr graph_;
uint32_t max_label_ = 0;
};
} // namespace
@ -1353,7 +1451,7 @@ void AscendAutoMonad::Run() {
AscendAutoMonadContext context(kg);
CallInfoFinder::Run(&context);
AscendAutoMonadConverter::Run(&context);
kernel_graph_->set_label_num(context.CurrentLabel());
kernel_graph_->set_label_num(context.CurrentLabel() + 1);
MS_LOG(DEBUG) << "Ascend auto-monad finish.";
DumpGraphForDebug(kernel_graph_);
}

Loading…
Cancel
Save