From b153ca0391fa95a3b311f0b408c0a464214415ab Mon Sep 17 00:00:00 2001 From: chuxing Date: Mon, 15 Mar 2021 10:21:54 +0800 Subject: [PATCH] fix variable fusion when variable only in subgraph --- ge/graph/load/model_manager/davinci_model.cc | 2 +- ge/graph/passes/variable_op_pass.cc | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 9ce4f595..d33d4b93 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -3904,7 +3904,7 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) } std::vector variable_node_list; - for (ge::NodePtr &node : graph->GetDirectNode()) { + for (ge::NodePtr &node : graph->GetAllNodes()) { if (node == nullptr) { continue; } diff --git a/ge/graph/passes/variable_op_pass.cc b/ge/graph/passes/variable_op_pass.cc index f1843d94..8f33335d 100644 --- a/ge/graph/passes/variable_op_pass.cc +++ b/ge/graph/passes/variable_op_pass.cc @@ -119,8 +119,9 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { return INTERNAL_ERROR; } + auto graph_id = GraphUtils::FindRootGraph(graph)->GetGraphID(); GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(), - GetContext().SessionId(), graph->GetGraphID()); + GetContext().SessionId(), graph_id); if (var_accelerate_ctrl_ == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null"); @@ -176,7 +177,7 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str()); return INTERNAL_ERROR; } - ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph->GetGraphID()); + ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph_id); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str()); return INTERNAL_ERROR;