fix case plugin error

pull/704/head
wjm 4 years ago
parent 19c9b79b03
commit acb3e2f6dd

@ -185,13 +185,17 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) {
const auto &parent_graph = compute_graph->GetParentGraph();
GE_CHECK_NOTNULL(parent_graph);
for (const NodePtr &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) {
continue;
}
bool flag = false;
(void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag);
if (!flag) {
for (const NodePtr &node : compute_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node->GetOpDesc());
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) {
continue;
}
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName());
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName());
}
}
return PostParseSubgraph(compute_graph, subgraph_name, parent_node);

@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
subgraph->SetName("Batch_" + std::to_string(i));
subgraph->SetParentNode(case_node_);
subgraph->SetParentGraph(graph);
(void)AttrUtils::SetBool(subgraph, "_no_reset_name", true);
graph->AddSubgraph(subgraph->GetName(), subgraph);
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const
}
}
return PostProcSubgraph(graph);
}
///
/// @ingroup ge
/// @brief Assign parent index for branches.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
auto func_desc = case_node_->GetOpDesc();
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
case_node_->GetType().c_str());
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS ||
parse_func_v2 == nullptr) {
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(),
case_node_->GetType().c_str());
return FAILED;
}
}
for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
const auto &subgraph = graph->GetSubgraph(name);
if (subgraph == nullptr) {
GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str());
return FAILED;
}
std::string subgraph_name;
GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name),
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());
auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
Status ret = FAILED;
if (post_func != nullptr) {
ret = post_func(subgraph_name, graph);
} else if (parse_func_v2 != nullptr) {
ret = parse_func_v2(subgraph_name.c_str(), graph);
}
if (ret != SUCCESS) {
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
case_node_->GetName().c_str(), case_node_->GetType().c_str());
return FAILED;
}
}
return SUCCESS;
}

@ -131,14 +131,6 @@ class MultiBatchClonePass : public GraphPass {
///
Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch);
///
/// @ingroup ge
/// @brief Assign parent index for branches.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status PostProcSubgraph(const ComputeGraphPtr &graph);
///
/// @ingroup ge
/// @brief Remove subgraph supend output anchor.

@ -29,6 +29,7 @@
#include "framework/omg/omg_inner_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/ge_context.h"
#include "graph/passes/data_pass.h"
#include "graph/passes/multi_batch_clone_pass.h"
#include "graph/passes/prune_pass.h"
#include "graph/preprocess/multi_batch_options.h"
@ -1697,6 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) {
if (multi_batch_with_switchn == nullptr) {
PassManager pass_manager;
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass));
return pass_manager.Run(graph);
}
}

Loading…
Cancel
Save