|
|
|
@ -48,26 +48,41 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) {
|
|
|
|
|
bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) {
|
|
|
|
|
if (comp_graph == nullptr) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::set<int64_t> stream_set;
|
|
|
|
|
std::set<std::string> label_set;
|
|
|
|
|
for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) {
|
|
|
|
|
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue);
|
|
|
|
|
int64_t stream_id = cur_node->GetOpDesc()->GetStreamId();
|
|
|
|
|
if (stream_id == kInvalidStream) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(),
|
|
|
|
|
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
|
|
|
|
|
stream_set.insert(stream_id);
|
|
|
|
|
|
|
|
|
|
std::string batch_label;
|
|
|
|
|
if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
|
|
|
|
|
label_set.insert(batch_label);
|
|
|
|
|
} else {
|
|
|
|
|
GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(),
|
|
|
|
|
cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(),
|
|
|
|
|
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize());
|
|
|
|
|
}
|
|
|
|
|
if (stream_set.size() > 1) {
|
|
|
|
|
GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.",
|
|
|
|
|
if (stream_set.size() > 1 || label_set.size() > 1) {
|
|
|
|
|
GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.",
|
|
|
|
|
comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size());
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!label_set.empty()) {
|
|
|
|
|
(void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin());
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -99,8 +114,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!IsSameStreamId(subgraph)) {
|
|
|
|
|
GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str());
|
|
|
|
|
if (!IsSameStreamIdOrBatchLabel(subgraph)) {
|
|
|
|
|
GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
OpDescPtr op_desc = nodes.at(0)->GetOpDesc();
|
|
|
|
@ -112,9 +127,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
run_context.stream = run_context.graphStreamList[stream_id];
|
|
|
|
|
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.",
|
|
|
|
|
subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
|
|
|
|
|
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)));
|
|
|
|
|
std::string batch_label;
|
|
|
|
|
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label);
|
|
|
|
|
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, "
|
|
|
|
|
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id,
|
|
|
|
|
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str());
|
|
|
|
|
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) {
|
|
|
|
|
GE_CHECK_NOTNULL(*iter);
|
|
|
|
|
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context);
|
|
|
|
|