|
|
|
@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace session {
|
|
|
|
|
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
|
|
|
|
|
auto &nodes = parent_graph->execution_order();
|
|
|
|
|
for (auto &node : nodes) {
|
|
|
|
|
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
|
|
|
|
return node;
|
|
|
|
|
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
|
|
|
|
|
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
|
|
|
|
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
|
|
|
|
|
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
|
|
|
|
if (memo->find(kg.get()) != memo->end()) {
|
|
|
|
@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|
|
|
|
if (target_graph_iter == graph_id_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
|
|
|
|
}
|
|
|
|
|
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
|
|
|
|
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
|
|
|
|
|
NOT_NULL(parameter));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|
|
|
|
return {partial_cnode, branch_kg};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
|
|
|
|
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
|
|
|
|
|
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
|
|
|
|
|
NotNull<AnfNodePtr> to) {
|
|
|
|
|
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
|
|
|
|
|
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
|
|
|
|
@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
|
|
|
|
|
<< to_outputs.size() << "]";
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < from_outputs.size(); i++) {
|
|
|
|
|
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
|
|
|
|
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
|
|
|
|
if (assign_node != nullptr) {
|
|
|
|
|
auto jump_node = GetJumpNode(from_graph, to_graph);
|
|
|
|
|
if (jump_node != nullptr) {
|
|
|
|
|
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
|
|
|
|
NotNull<AnfNodePtr> to) {
|
|
|
|
|
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
|
|
|
|
NotNull<AnfNodePtr> to) {
|
|
|
|
|
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
|
|
|
|
|
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
|
|
|
|
|
return;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (from.get() == to.get()) {
|
|
|
|
|
return;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
|
|
|
|
|
<< to->DebugString();
|
|
|
|
@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|
|
|
|
assign_node->set_abstract(to->abstract());
|
|
|
|
|
// append the assign at the end of from graph
|
|
|
|
|
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
|
|
|
|
return assign_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
|
|
|
|