|
|
|
@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
|
|
|
|
|
subgraph->SetName("Batch_" + std::to_string(i));
|
|
|
|
|
subgraph->SetParentNode(case_node_);
|
|
|
|
|
subgraph->SetParentGraph(graph);
|
|
|
|
|
(void)AttrUtils::SetBool(subgraph, "_no_reset_name", true);
|
|
|
|
|
graph->AddSubgraph(subgraph->GetName(), subgraph);
|
|
|
|
|
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
|
|
|
|
|
|
|
|
|
@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return PostProcSubgraph(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///
|
|
|
|
|
/// @ingroup ge
|
|
|
|
|
/// @brief Assign parent index for branches.
|
|
|
|
|
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
|
|
|
|
|
/// @return 0: SUCCESS / others: FAILED
|
|
|
|
|
///
|
|
|
|
|
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
|
|
|
|
|
auto func_desc = case_node_->GetOpDesc();
|
|
|
|
|
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
|
|
|
|
|
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
|
|
|
|
|
if (post_func == nullptr) {
|
|
|
|
|
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
|
|
|
|
|
case_node_->GetType().c_str());
|
|
|
|
|
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS ||
|
|
|
|
|
parse_func_v2 == nullptr) {
|
|
|
|
|
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(),
|
|
|
|
|
case_node_->GetType().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
|
|
|
|
|
const auto &subgraph = graph->GetSubgraph(name);
|
|
|
|
|
if (subgraph == nullptr) {
|
|
|
|
|
GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string subgraph_name;
|
|
|
|
|
GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name),
|
|
|
|
|
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());
|
|
|
|
|
|
|
|
|
|
auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
|
|
|
|
|
Status ret = FAILED;
|
|
|
|
|
if (post_func != nullptr) {
|
|
|
|
|
ret = post_func(subgraph_name, graph);
|
|
|
|
|
} else if (parse_func_v2 != nullptr) {
|
|
|
|
|
ret = parse_func_v2(subgraph_name.c_str(), graph);
|
|
|
|
|
}
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
|
|
|
|
|
case_node_->GetName().c_str(), case_node_->GetType().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|