From 42bf687a43482dd1554c2afdab72895cbe16ca35 Mon Sep 17 00:00:00 2001 From: wjm Date: Sat, 19 Dec 2020 14:43:29 +0800 Subject: [PATCH 1/3] mult batch --- ge/graph/build/stream_graph_optimizer.cc | 36 ++++++++++++++----- ge/graph/build/stream_graph_optimizer.h | 2 +- ge/graph/build/task_generator.cc | 2 +- .../load/new_model_manager/zero_copy_task.cc | 4 --- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc index 2933d413..f86f846e 100644 --- a/ge/graph/build/stream_graph_optimizer.cc +++ b/ge/graph/build/stream_graph_optimizer.cc @@ -48,26 +48,42 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap } } -bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { +bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) { if (comp_graph == nullptr) { return false; } std::set stream_set; + std::set label_set; for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) { GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue); int64_t stream_id = cur_node->GetOpDesc()->GetStreamId(); if (stream_id == kInvalidStream) { continue; } - GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(), + stream_set.insert(stream_id); + + std::string batch_label; + if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { + label_set.insert(batch_label); + } else { + GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(), + cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); + continue; + } + + GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(), comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); stream_set.insert(stream_id); } - if (stream_set.size() > 1) { - GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.", + if (stream_set.size() > 1 || label_set.size() > 1) { + GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size()); return false; } + + if (!label_set.empty()) { + (void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin()); + } return true; } @@ -99,8 +115,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com continue; } - if (!IsSameStreamId(subgraph)) { - GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); + if (!IsSameStreamIdOrBatchLabel(subgraph)) { + GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str()); continue; } OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); @@ -112,9 +128,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com return FAILED; } run_context.stream = run_context.graphStreamList[stream_id]; - GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", - subgraph->GetName().c_str(), engine_name.c_str(), stream_id, - static_cast(reinterpret_cast(run_context.stream))); + std::string batch_label; + (void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label); + GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " + "batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id, + static_cast(reinterpret_cast(run_context.stream)), batch_label.c_str()); for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { GE_CHECK_NOTNULL(*iter); Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); diff --git a/ge/graph/build/stream_graph_optimizer.h b/ge/graph/build/stream_graph_optimizer.h index b0eea135..d69fa7ba 100644 --- a/ge/graph/build/stream_graph_optimizer.h +++ b/ge/graph/build/stream_graph_optimizer.h @@ -41,7 +41,7 @@ class StreamGraphOptimizer { private: void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); - bool IsSameStreamId(const ComputeGraphPtr &comp_graph); + bool IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph); }; } // namespace ge #endif // GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_ diff --git a/ge/graph/build/task_generator.cc b/ge/graph/build/task_generator.cc index b506f945..2089ad31 100755 --- a/ge/graph/build/task_generator.cc +++ b/ge/graph/build/task_generator.cc @@ -567,7 +567,7 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector &ops, bool is_ continue; } string op_type = op_desc->GetType(); - if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0)) { + if (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0) { continuous_op_lists.emplace_back(vector()); } else { continuous_op_lists.back().emplace_back(op_desc); diff --git a/ge/graph/load/new_model_manager/zero_copy_task.cc b/ge/graph/load/new_model_manager/zero_copy_task.cc index 2609cb4b..98dccb3c 100755 --- a/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -124,10 +124,6 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma auto &cur_pair = *iter; uint8_t *args_info = args_info_.data(); for (auto offset : cur_pair.second) { - if (!CheckDynamicBatch(batch_addrs, batch_label, reinterpret_cast(args_addr_ + offset))) { - continue; - } - auto dst_addr = static_cast(buffer_addr); GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); From a658c30e40b42b07969659ceff472a899c0a5c35 Mon Sep 17 00:00:00 2001 From: wjm Date: Sat, 19 Dec 2020 18:17:07 +0800 Subject: [PATCH 2/3] mult batch --- ge/graph/build/stream_graph_optimizer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/ge/graph/build/stream_graph_optimizer.cc b/ge/graph/build/stream_graph_optimizer.cc index f86f846e..05049818 100644 --- a/ge/graph/build/stream_graph_optimizer.cc +++ b/ge/graph/build/stream_graph_optimizer.cc @@ -73,7 +73,6 @@ bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &com GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(), comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); - stream_set.insert(stream_id); } if (stream_set.size() > 1 || label_set.size() > 1) { GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", From add33b1b7228b486c1c0fe1e16a3d95f0e345a88 Mon Sep 17 00:00:00 2001 From: wjm Date: Mon, 21 Dec 2020 10:12:14 +0800 Subject: [PATCH 3/3] mult batch --- ge/graph/preprocess/multi_batch_copy_graph.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ge/graph/preprocess/multi_batch_copy_graph.cc b/ge/graph/preprocess/multi_batch_copy_graph.cc index 9ab74d70..a90f145e 100644 --- a/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -1407,11 +1407,13 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); - if (multi_batch_with_case != nullptr) { - PassManager pass_manager; - GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); - return pass_manager.Run(graph); + if (GetLocalOmgContext().dynamic_node_type.empty()) { + const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN"); + if (multi_batch_with_switchn == nullptr) { + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); + return pass_manager.Run(graph); + } } if (!GetLocalOmgContext().need_multi_batch) { GELOGI("No need to process_multi for no_train graph.");