From acb3e2f6dda27421d10d5f99a91249589dd8ec84 Mon Sep 17 00:00:00 2001 From: wjm Date: Wed, 23 Dec 2020 20:50:16 +0800 Subject: [PATCH 1/5] fix case plugin error --- ge/graph/passes/data_pass.cc | 16 +++--- ge/graph/passes/multi_batch_clone_pass.cc | 50 +------------------ ge/graph/passes/multi_batch_clone_pass.h | 8 --- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 + 4 files changed, 13 insertions(+), 63 deletions(-) diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 5bbd2fb1..1c897214 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -185,13 +185,17 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { const auto &parent_graph = compute_graph->GetParentGraph(); GE_CHECK_NOTNULL(parent_graph); - for (const NodePtr &node : compute_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { - continue; - } + bool flag = false; + (void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); + if (!flag) { + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { + continue; + } - node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + } } return PostParseSubgraph(compute_graph, subgraph_name, parent_node); diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 87d9749a..496ad214 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const subgraph->SetName("Batch_" + std::to_string(i)); subgraph->SetParentNode(case_node_); subgraph->SetParentGraph(graph); + (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); @@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const } } - return PostProcSubgraph(graph); -} - -/// -/// @ingroup ge -/// @brief Assign parent index for branches. -/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. -/// @return 0: SUCCESS / others: FAILED -/// -Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { - auto func_desc = case_node_->GetOpDesc(); - domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; - auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); - if (post_func == nullptr) { - GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), - case_node_->GetType().c_str()); - if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || - parse_func_v2 == nullptr) { - GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), - case_node_->GetType().c_str()); - return FAILED; - } - } - - for (const auto &name : func_desc->GetSubgraphInstanceNames()) { - const auto &subgraph = graph->GetSubgraph(name); - if (subgraph == nullptr) { - GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str()); - return FAILED; - } - - std::string subgraph_name; - GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name), - "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); - - auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); - Status ret = FAILED; - if (post_func != nullptr) { - ret = post_func(subgraph_name, graph); - } else if (parse_func_v2 != nullptr) { - ret = parse_func_v2(subgraph_name.c_str(), graph); - } - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), - case_node_->GetName().c_str(), case_node_->GetType().c_str()); - return FAILED; - } - } - return SUCCESS; } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 1155dfc8..5921970a 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -131,14 +131,6 @@ class MultiBatchClonePass : public GraphPass { /// Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); - /// - /// @ingroup ge - /// @brief Assign parent index for branches. - /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. - /// @return 0: SUCCESS / others: FAILED - /// - Status PostProcSubgraph(const ComputeGraphPtr &graph); - /// /// @ingroup ge /// @brief Remove subgraph supend output anchor. diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index c8880b2e..754df184 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -29,6 +29,7 @@ #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" +#include "graph/passes/data_pass.h" #include "graph/passes/multi_batch_clone_pass.h" #include "graph/passes/prune_pass.h" #include "graph/preprocess/multi_batch_options.h" @@ -1697,6 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); return pass_manager.Run(graph); } } From 4f97bc9abdfc2d997fbf9406b07034fb168c9f98 Mon Sep 17 00:00:00 2001 From: wjm Date: Wed, 23 Dec 2020 21:00:20 +0800 Subject: [PATCH 2/5] fix case plugin error --- ge/graph/passes/data_pass.cc | 2 +- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 1c897214..2a689cd5 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -194,7 +194,7 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { continue; } - node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); } } diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 754df184..4f7cd9fb 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1698,7 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); + GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // 子图分配parent index return pass_manager.Run(graph); } } From f91fdad16a52b541ed3d89e88661d0a88371c792 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 24 Dec 2020 10:42:58 +0800 Subject: [PATCH 3/5] case plugin --- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 4f7cd9fb..eae97b04 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1698,7 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // 子图分配parent index + GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // set subgraph parent index return pass_manager.Run(graph); } } From 113b80da8b56099bf7cf88aa8370cb5246208831 Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 25 Dec 2020 11:38:22 +0800 Subject: [PATCH 4/5] remove data_pass --- ge/graph/passes/data_pass.cc | 16 +++---- ge/graph/passes/multi_batch_clone_pass.cc | 45 ++++++++++++++++--- ge/graph/passes/multi_batch_clone_pass.h | 14 ++++-- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 - 4 files changed, 57 insertions(+), 20 deletions(-) diff --git a/ge/graph/passes/data_pass.cc b/ge/graph/passes/data_pass.cc index 2a689cd5..5bbd2fb1 100644 --- a/ge/graph/passes/data_pass.cc +++ b/ge/graph/passes/data_pass.cc @@ -185,17 +185,13 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { const auto &parent_graph = compute_graph->GetParentGraph(); GE_CHECK_NOTNULL(parent_graph); - bool flag = false; - (void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); - if (!flag) { - for (const NodePtr &node : compute_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { - continue; - } - - node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); + for (const NodePtr &node : compute_graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { + continue; } + + node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); } return PostParseSubgraph(compute_graph, subgraph_name, parent_node); diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 496ad214..37f4b637 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -503,12 +503,24 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { /// /// @ingroup ge -/// @brief Set shape to Data node in branch. -/// @param [in] const NodePtr &data: data in branch. +/// @brief Update Data node in Subgraph. +/// @param [in] const NodePtr &data: data in Subgraph. /// @param [in] size_t index: The batch index. /// @return 0: SUCCESS / others: FAILED /// -Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) { +Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { + int node_index = -1; + if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { + GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); + return FAILED; + } + + int parent_index = node_index + 1; + if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str()); + return FAILED; + } + auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); const auto &dims = data_shape.GetDims(); if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { @@ -581,13 +593,15 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); + GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), + "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); const string key_name = "branches" + std::to_string(i); op_desc->AddSubgraphName(key_name); op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); for (const auto &data : input_nodes) { - GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); } } @@ -596,7 +610,28 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const const auto &op_desc = n->GetOpDesc(); op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0"); if (n->GetType() == DATA) { - GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str()); + GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str()); + } + } + + return SUCCESS; +} + +/// +/// @ingroup ge +/// @brief Update output_node in Subgraph. +/// @param [in] const NodePtr &data: output_node in Subgraph. +/// @return 0: SUCCESS / others: FAILED +/// +Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { + const auto &op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { + GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(tensor); + if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); + return FAILED; } } diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 5921970a..3dbf91db 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -105,12 +105,20 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge - /// @brief Set shape to Data node in branch. - /// @param [in] const NodePtr &data: data in branch. + /// @brief Update Data node in Subgraph. + /// @param [in] const NodePtr &data: data in Subgraph. /// @param [in] size_t index: The batch index. /// @return 0: SUCCESS / others: FAILED /// - Status UpdateShapeToData(const NodePtr &data, size_t index); + Status UpdateSubgraphData(const NodePtr &data, size_t index); + + /// + /// @ingroup ge + /// @brief Update output_node in Subgraph. + /// @param [in] const NodePtr &data: output_node in Subgraph. + /// @return 0: SUCCESS / others: FAILED + /// + Status UpdateSubgraphOutput(const NodePtr &output_node); /// /// @ingroup ge diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index eae97b04..c8880b2e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -29,7 +29,6 @@ #include "framework/omg/omg_inner_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" -#include "graph/passes/data_pass.h" #include "graph/passes/multi_batch_clone_pass.h" #include "graph/passes/prune_pass.h" #include "graph/preprocess/multi_batch_options.h" @@ -1698,7 +1697,6 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { if (multi_batch_with_switchn == nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); // set subgraph parent index return pass_manager.Run(graph); } } From df8f69cb418c58dfb074fa78dbdadf4940a0ddfe Mon Sep 17 00:00:00 2001 From: wjm Date: Fri, 25 Dec 2020 11:44:11 +0800 Subject: [PATCH 5/5] fix --- ge/graph/passes/multi_batch_clone_pass.cc | 3 +-- ge/graph/passes/multi_batch_clone_pass.h | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ge/graph/passes/multi_batch_clone_pass.cc b/ge/graph/passes/multi_batch_clone_pass.cc index 37f4b637..872f94fb 100755 --- a/ge/graph/passes/multi_batch_clone_pass.cc +++ b/ge/graph/passes/multi_batch_clone_pass.cc @@ -590,7 +590,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const subgraph->SetName("Batch_" + std::to_string(i)); subgraph->SetParentNode(case_node_); subgraph->SetParentGraph(graph); - (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); graph->AddSubgraph(subgraph->GetName(), subgraph); all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), @@ -620,7 +619,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const /// /// @ingroup ge /// @brief Update output_node in Subgraph. -/// @param [in] const NodePtr &data: output_node in Subgraph. +/// @param [in] const NodePtr &output_node: output_node in Subgraph. /// @return 0: SUCCESS / others: FAILED /// Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { diff --git a/ge/graph/passes/multi_batch_clone_pass.h b/ge/graph/passes/multi_batch_clone_pass.h index 3dbf91db..ee137b5a 100755 --- a/ge/graph/passes/multi_batch_clone_pass.h +++ b/ge/graph/passes/multi_batch_clone_pass.h @@ -115,7 +115,7 @@ class MultiBatchClonePass : public GraphPass { /// /// @ingroup ge /// @brief Update output_node in Subgraph. - /// @param [in] const NodePtr &data: output_node in Subgraph. + /// @param [in] const NodePtr &output_node: output_node in Subgraph. /// @return 0: SUCCESS / others: FAILED /// Status UpdateSubgraphOutput(const NodePtr &output_node);