|
|
|
@ -19,6 +19,7 @@
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include "session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "utils/union_find_set.h"
|
|
|
|
|
#include "device/ascend/ascend_label_assign.h"
|
|
|
|
|
|
|
|
|
|
static constexpr size_t kCNodePrim = 0;
|
|
|
|
|
static constexpr size_t kCNodeCallArg = 1;
|
|
|
|
@ -35,17 +36,25 @@ namespace mindspore {
|
|
|
|
|
namespace session {
|
|
|
|
|
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
|
|
|
|
|
auto &nodes = parent_graph->execution_order();
|
|
|
|
|
CNodePtr last_jump_node = nullptr;
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
|
|
|
|
|
if (child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
|
|
|
|
return node;
|
|
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
|
|
|
|
|
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
|
|
|
|
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
|
|
|
|
|
}
|
|
|
|
|
last_jump_node = node;
|
|
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
|
|
|
|
|
if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
|
|
|
|
child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) {
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
last_jump_node = node;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (last_jump_node == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
|
|
|
|
}
|
|
|
|
|
return last_jump_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
|
|
|
|
@ -90,6 +99,9 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
|
|
|
|
|
if (!arg->isa<Parameter>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
union_find_set->Union(arg, para);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -133,13 +145,8 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
|
|
|
|
|
auto parameter_reuse_sets = parameter_set->GetSets();
|
|
|
|
|
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
|
|
|
|
|
if (parameter_reuse_set.size() <= 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
|
|
|
|
|
const std::set<AnfNodePtr> ¶meter_reuse_set) {
|
|
|
|
|
AnfNodePtr main_parameter = key;
|
|
|
|
|
std::set<AnfNodePtr> root_inputs_set;
|
|
|
|
|
const auto &root_inputs_vector = root_kg->inputs();
|
|
|
|
@ -150,7 +157,16 @@ static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return main_parameter;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
|
|
|
|
|
auto parameter_reuse_sets = parameter_set->GetSets();
|
|
|
|
|
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
|
|
|
|
|
if (parameter_reuse_set.size() <= 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
|
|
|
|
|
std::set<KernelGraphPtr> memo;
|
|
|
|
|
RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
|
|
|
|
|
}
|
|
|
|
@ -168,6 +184,7 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
|
|
|
|
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
|
|
|
|
std::set<KernelGraphPtr> memo;
|
|
|
|
|
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
|
|
|
|
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
|
|
|
|
|
std::map<uint32_t, KernelGraphPtr> graph_id_map;
|
|
|
|
|
for (auto &g : memo) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(g);
|
|
|
|
@ -177,12 +194,13 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
|
|
|
|
}
|
|
|
|
|
graph_id_map[g->graph_id()] = g;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Insert Assign
|
|
|
|
|
ChildGraphDataAssign(graph_id_map);
|
|
|
|
|
// Make UnionFindSet
|
|
|
|
|
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
|
|
|
|
|
// Reuse Parameter
|
|
|
|
|
ReuseParameter(kg, NOT_NULL(¶meter_set));
|
|
|
|
|
// Insert Assign
|
|
|
|
|
ChildGraphDataAssign(graph_id_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
@ -193,6 +211,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
|
|
|
|
|
for (auto &iter : graph_id_map) {
|
|
|
|
|
auto &kg = iter.second;
|
|
|
|
|
MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kg);
|
|
|
|
|
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
|
|
|
|
|
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
|
|
|
@ -206,8 +225,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|
|
|
|
} else {
|
|
|
|
|
memo.emplace(parameter, arg);
|
|
|
|
|
}
|
|
|
|
|
if (arg->isa<Parameter>()) {
|
|
|
|
|
auto unreuse_args_map = kg->unreuse_args();
|
|
|
|
|
auto unreuse_arg_iter = unreuse_args_map.find(arg);
|
|
|
|
|
if (unreuse_arg_iter == unreuse_args_map.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(arg);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
if (!arg->isa<Parameter>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
|
|
|
|
<< ", arg:" << arg->DebugString();
|
|
|
|
|
continue;
|
|
|
|
@ -220,6 +245,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|
|
|
|
NOT_NULL(parameter));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
kg->SetExecOrderByDefault();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -353,7 +379,6 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|
|
|
|
// 5 recurse sub graph
|
|
|
|
|
CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
|
|
|
|
|
new_inputs.push_back(sub_label);
|
|
|
|
|
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
|
|
|
|
|
cur_node->set_inputs(new_inputs);
|
|
|
|
|
cur_node->set_abstract(nullptr);
|
|
|
|
|
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
|
|
|
@ -394,7 +419,6 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|
|
|
|
}
|
|
|
|
|
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
|
|
|
|
|
|
|
|
|
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
|
|
|
|
|
cur_node->set_inputs(new_switch_inputs);
|
|
|
|
|
cur_node->set_abstract(nullptr);
|
|
|
|
|
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
|
|
|
|
@ -477,6 +501,16 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
|
|
|
|
|
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
|
|
|
|
if (assign_node != nullptr) {
|
|
|
|
|
auto jump_node = GetJumpNode(from_graph, to_graph);
|
|
|
|
|
const auto &from_graph_exe_order = from_graph->execution_order();
|
|
|
|
|
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
|
|
|
|
|
if (jump_node_iter == from_graph_exe_order.end()) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(jump_node);
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
|
|
|
|
|
}
|
|
|
|
|
// insert assign between jump_node -1 and jump_node
|
|
|
|
|
if (jump_node_iter != from_graph_exe_order.begin()) {
|
|
|
|
|
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
|
|
|
|
|
}
|
|
|
|
|
if (jump_node != nullptr) {
|
|
|
|
|
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
|
|
|
|
}
|
|
|
|
@ -501,8 +535,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg,
|
|
|
|
|
auto assign_node = kg->NewCNode(inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(assign_node);
|
|
|
|
|
assign_node->set_abstract(to->abstract());
|
|
|
|
|
// append the assign at the end of from graph
|
|
|
|
|
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
|
|
|
|
return assign_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -527,7 +559,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> execution_order;
|
|
|
|
|
uint32_t child_order_index = 0;
|
|
|
|
|
|
|
|
|
|
for (auto &node : cnodes) {
|
|
|
|
|
execution_order.push_back(node);
|
|
|
|
|
if (node == graph->get_end_goto()) {
|
|
|
|
|