|
|
|
@ -1238,7 +1238,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs";
|
|
|
|
|
int32_t index = 0;
|
|
|
|
|
std::vector<KernelGraphPtr> child_graphs;
|
|
|
|
|
auto start_label = graph->get_start_label();
|
|
|
|
|
auto start_label_id = AnfAlgo::GetNodeAttr<uint32_t>(graph->get_start_label(), kAttrLabelIndex);
|
|
|
|
|
auto end_node = graph->get_end_goto();
|
|
|
|
|
ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0);
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
|
|
|
@ -1247,9 +1247,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
|
|
|
|
|
auto kg = graphs_[graph_id];
|
|
|
|
|
auto nodes = kg->execution_order();
|
|
|
|
|
for (uint32_t i = 0; i < nodes.size(); i++) {
|
|
|
|
|
if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName &&
|
|
|
|
|
(AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) ==
|
|
|
|
|
AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) {
|
|
|
|
|
if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) {
|
|
|
|
|
if (i < (nodes.size() - 1)) {
|
|
|
|
|
new_inputs.push_back(nodes[i + 1]);
|
|
|
|
|
} else {
|
|
|
|
|