!669 change l2 buffer for mult_batch

From: @jiming6
Reviewed-by: @liujunzhu,@xchu42
Signed-off-by: @wqtshg
pull/669/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5d743d9ea2

@ -48,26 +48,41 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap
}
}
bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) {
bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) {
if (comp_graph == nullptr) {
return false;
}
std::set<int64_t> stream_set;
std::set<std::string> label_set;
for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) {
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue);
int64_t stream_id = cur_node->GetOpDesc()->GetStreamId();
if (stream_id == kInvalidStream) {
continue;
}
GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
stream_set.insert(stream_id);
std::string batch_label;
if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
label_set.insert(batch_label);
} else {
GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(),
cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id);
continue;
}
GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(),
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
}
if (stream_set.size() > 1) {
GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.",
if (stream_set.size() > 1 || label_set.size() > 1) {
GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.",
comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size());
return false;
}
if (!label_set.empty()) {
(void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin());
}
return true;
}
@ -99,8 +114,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
continue;
}
if (!IsSameStreamId(subgraph)) {
GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str());
if (!IsSameStreamIdOrBatchLabel(subgraph)) {
GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str());
continue;
}
OpDescPtr op_desc = nodes.at(0)->GetOpDesc();
@ -112,9 +127,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
return FAILED;
}
run_context.stream = run_context.graphStreamList[stream_id];
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.",
subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)));
std::string batch_label;
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label);
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, "
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str());
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) {
GE_CHECK_NOTNULL(*iter);
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context);

@ -41,7 +41,7 @@ class StreamGraphOptimizer {
private:
void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map);
bool IsSameStreamId(const ComputeGraphPtr &comp_graph);
bool IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph);
};
} // namespace ge
#endif // GE_GRAPH_BUILD_OPTIMIZE_STREAM_GRAPH_H_

@ -567,7 +567,7 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_
continue;
}
string op_type = op_desc->GetType();
if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0)) {
if (!op_desc->GetSubgraphInstanceNames().empty() || separator_types.count(op_type) != 0) {
continuous_op_lists.emplace_back(vector<OpDescPtr>());
} else {
continuous_op_lists.back().emplace_back(op_desc);

@ -1407,11 +1407,13 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
}
Status ProcessMultiBatch(ComputeGraphPtr &graph) {
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE");
if (multi_batch_with_case != nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
return pass_manager.Run(graph);
if (GetLocalOmgContext().dynamic_node_type.empty()) {
const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
if (multi_batch_with_switchn == nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
return pass_manager.Run(graph);
}
}
if (!GetLocalOmgContext().need_multi_batch) {
GELOGI("No need to process_multi for no_train graph.");

Loading…
Cancel
Save