From d3515b16248b43f5a8285428eedcc16cabdd3125 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 14 Jan 2021 10:24:40 +0800 Subject: [PATCH 1/2] change mult batch to switchn --- ge/graph/preprocess/multi_batch_copy_graph.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 16535a59..60d19b14 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1692,13 +1692,11 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - if (GetLocalOmgContext().dynamic_node_type.empty()) { - const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); - if (multi_batch_with_switchn == nullptr) { - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - return pass_manager.Run(graph); - } + const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); + if (multi_batch_with_case == nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); } if (!GetLocalOmgContext().need_multi_batch) { From 81cd9527aac4574144721f739d2bbc7c08385ed6 Mon Sep 17 00:00:00 2001 From: wjm Date: Thu, 14 Jan 2021 11:00:24 +0800 Subject: [PATCH 2/2] fix error --- ge/graph/preprocess/multi_batch_copy_graph.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 60d19b14..aa1812ba 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1693,7 +1693,7 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { Status ProcessMultiBatch(ComputeGraphPtr &graph) { const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); - if (multi_batch_with_case == nullptr) { + if (multi_batch_with_case != nullptr) { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); return pass_manager.Run(graph);