From aa6dcd9262efa53fd311c80eedc91e23a6fe2783 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 12 Nov 2020 19:59:08 +0800 Subject: [PATCH] Fix Label for dynamic graph. --- ge/graph/build/label_allocator.cc | 31 +++++++++++++++-------- ge/graph/build/model_builder.cc | 4 +-- ge/graph/passes/memcpy_addr_async_pass.cc | 7 ++++- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/ge/graph/build/label_allocator.cc b/ge/graph/build/label_allocator.cc index fad7d0c2..51f31003 100644 --- a/ge/graph/build/label_allocator.cc +++ b/ge/graph/build/label_allocator.cc @@ -32,11 +32,6 @@ Status LabelAllocator::AssignFunctionalLabels() { return INTERNAL_ERROR; } - if (compute_graph_->GetGraphUnknownFlag()) { - GELOGD("Graph[%s] is unknown graph, skip label allocator.", compute_graph_->GetName().c_str()); - return SUCCESS; - } - // Add label task for sub graph. GELOGI("AssignFunctionalLabels start: %s.", compute_graph_->GetName().c_str()); std::set functional_nodes; @@ -62,7 +57,7 @@ Status LabelAllocator::AssignFunctionalLabels() { } (void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index); - GELOGI("AssignFunctionalLabels success."); + GELOGI("AssignFunctionalLabels success, Num: %u.", label_index); return SUCCESS; } @@ -72,13 +67,29 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::setGetParentNode(); - if (parent == nullptr) { - GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", graph->GetName().c_str()); + if (graph->GetGraphUnknownFlag()) { + GELOGD("Graph[%s] is unknown graph, skip label allocator.", graph->GetName().c_str()); + return true; + } + + NodePtr func_node = graph->GetParentNode(); + if (func_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Parent functional node not set: %s.", graph->GetName().c_str()); return false; } - (void)functional_nodes.insert(parent); // unique functional node. + ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); + if (owner_graph == nullptr) { + GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", func_node->GetName().c_str()); + return false; + } + + if (owner_graph->GetGraphUnknownFlag()) { + GELOGD("Graph[%s] is unknown graph, skip label allocator.", owner_graph->GetName().c_str()); + return true; + } + + (void)functional_nodes.insert(func_node); // unique functional node. return true; } } // namespace ge diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 56a5b4dc..f382c24a 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -690,8 +690,8 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); // Assign functional op labels. - label_num_ = 0; - (void)AttrUtils::GetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_num_); + auto root_graph = GraphUtils::FindRootGraph(compute_graph_); + (void)AttrUtils::GetInt(*root_graph, ATTR_MODEL_LABEL_NUM, label_num_); GE_TIMESTAMP_START(AssignMemory); MemoryAssigner mem_assigner(compute_graph_); diff --git a/ge/graph/passes/memcpy_addr_async_pass.cc b/ge/graph/passes/memcpy_addr_async_pass.cc index a9e3f4c4..8bb16286 100755 --- a/ge/graph/passes/memcpy_addr_async_pass.cc +++ b/ge/graph/passes/memcpy_addr_async_pass.cc @@ -25,6 +25,10 @@ namespace ge { Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); + if (graph->GetGraphUnknownFlag()) { + GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); + return SUCCESS; + } int64_t value = 0; rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value); @@ -201,9 +205,10 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr const OutDataAnchorPtr &out_data_anchor, const NodePtr &out_of_user_data) { GELOGD("Start CreateMemcpyAddrAsyncNode."); + static uint32_t new_node_index = 0; OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); - std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC; + std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++); OpDescPtr op_desc = MakeShared(node_name, MEMCPYADDRASYNC); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);