diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 5bbd2fb1..1c897214 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -185,13 +185,17 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { const auto &parent_graph = compute_graph->GetParentGraph(); GE_CHECK_NOTNULL(parent_graph); - for (const NodePtr &node : compute_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { - continue; - } + bool flag = false; + (void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); + if (!flag) { + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { + continue; + } - node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + } } return PostParseSubgraph(compute_graph, subgraph_name, parent_node); diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 87d9749a..496ad214 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const subgraph->SetName("Batch_" + std::to_string(i)); subgraph->SetParentNode(case_node_); subgraph->SetParentGraph(graph); + (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); @@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const } } - return PostProcSubgraph(graph); -} - -/// -/// @ingroup ge -/// @brief Assign parent index for branches. -/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. -/// @return 0: SUCCESS / others: FAILED -/// -Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { - auto func_desc = case_node_->GetOpDesc(); - domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; - auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); - if (post_func == nullptr) { - GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), - case_node_->GetType().c_str()); - if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || - parse_func_v2 == nullptr) { - GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), - case_node_->GetType().c_str()); - return FAILED; - } - } - - for (const auto &name : func_desc->GetSubgraphInstanceNames()) { - const auto &subgraph = graph->GetSubgraph(name); - if (subgraph == nullptr) { - GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str()); - return FAILED; - } - - std::string subgraph_name; - GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name), - "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); - - auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); - Status ret = FAILED; - if (post_func != nullptr) { - ret = post_func(subgraph_name, graph); - } else if (parse_func_v2 != nullptr) { - ret = parse_func_v2(subgraph_name.c_str(), graph); - } - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), - case_node_->GetName().c_str(), case_node_->GetType().c_str()); - return FAILED; - } - } - return SUCCESS; } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 1155dfc8..5921970a 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -131,14 +131,6 @@ class MultiBatchClonePass : public GraphPass { /// Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); - /// - /// @ingroup ge - /// @brief Assign parent index for branches. - /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. - /// @return 0: SUCCESS / others: FAILED - /// - Status PostProcSubgraph(const ComputeGraphPtr &graph); - /// /// @ingroup ge /// @brief Remove subgraph supend output anchor. diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index c8880b2e..754df184 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -29,6 +29,7 @@ #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" +#include "graph/passes/data_pass.h" #include "graph/passes/multi_batch_clone_pass.h" #include "graph/passes/prune_pass.h" #include "graph/preprocess/multi_batch_options.h" @@ -1697,6 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); return pass_manager.Run(graph); } }