From 137007be880f40dd323a6277f7b630162f524807 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Wed, 13 May 2020 16:46:57 +0800 Subject: [PATCH] only call HasNodeAttr for CNodePtr type --- .../ccsrc/device/ascend/tasksink/task_generator.cc | 3 ++- mindspore/ccsrc/device/kernel_runtime.cc | 3 ++- mindspore/ccsrc/kernel/mng/label_goto.cc | 3 ++- mindspore/ccsrc/kernel/mng/label_set.cc | 3 ++- mindspore/ccsrc/kernel/mng/label_switch.cc | 3 ++- .../pre_activate/ascend/buffer_fusion/buffer_fusion.cc | 10 +++++++--- .../ascend/enhancer/getnext_memcpy_elimination.cc | 2 +- .../ascend/enhancer/insert_memcpy_async_for_getnext.cc | 7 ++++--- .../ccsrc/pre_activate/common/fusion_id_allocator.cc | 9 ++++++++- mindspore/ccsrc/session/anf_runtime_algorithm.cc | 6 +----- mindspore/ccsrc/session/anf_runtime_algorithm.h | 2 +- 11 files changed, 32 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc index 9e50480087..e7b3298b91 100644 --- a/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/device/ascend/tasksink/task_generator.cc @@ -45,7 +45,8 @@ void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressP if (anf_node_ptr->inputs().size() != 2) { MS_LOG(EXCEPTION) << "atomic Addr clean Node Input nodes not equal 2."; } - auto pre_node = anf_node_ptr->inputs()[1]; + MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); + auto pre_node = (anf_node_ptr->inputs()[1])->cast(); // set clean output addr if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) { auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAutomicOutputIndexs); diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 4f10a174b3..00c557f2d5 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -546,7 +546,8 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList if (cnode->inputs().size() != 2) { MS_LOG(EXCEPTION) << "Atomic Addr clean Node Input nodes not equal 2."; } - auto pre_node = cnode->inputs()[1]; + MS_EXCEPTION_IF_NULL(cnode->inputs()[1]); + auto pre_node = (cnode->inputs()[1])->cast(); // set clean output address if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) { auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAutomicOutputIndexs); diff --git a/mindspore/ccsrc/kernel/mng/label_goto.cc b/mindspore/ccsrc/kernel/mng/label_goto.cc index 674e48fb00..454ce7089f 100644 --- a/mindspore/ccsrc/kernel/mng/label_goto.cc +++ b/mindspore/ccsrc/kernel/mng/label_goto.cc @@ -34,7 +34,8 @@ LabelGotoKernel::~LabelGotoKernel() {} bool LabelGotoKernel::Init(const AnfNodePtr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); MS_LOG(INFO) << "LabelGotoKernel init"; - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) { + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { MS_LOG(EXCEPTION) << "LabelGotoKernel has no attr label_index"; } auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); diff --git a/mindspore/ccsrc/kernel/mng/label_set.cc b/mindspore/ccsrc/kernel/mng/label_set.cc index 9041867e5f..314aba2d4c 100644 --- a/mindspore/ccsrc/kernel/mng/label_set.cc +++ b/mindspore/ccsrc/kernel/mng/label_set.cc @@ -34,7 +34,8 @@ LabelSetKernel::~LabelSetKernel() {} bool LabelSetKernel::Init(const AnfNodePtr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); MS_LOG(INFO) << "LabelSetKernel init"; - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, anf_node)) { + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; } auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); diff --git a/mindspore/ccsrc/kernel/mng/label_switch.cc b/mindspore/ccsrc/kernel/mng/label_switch.cc index ac8dafa933..bd45987e51 100644 --- a/mindspore/ccsrc/kernel/mng/label_switch.cc +++ b/mindspore/ccsrc/kernel/mng/label_switch.cc @@ -38,7 +38,8 @@ LabelSwitchKernel::~LabelSwitchKernel() {} bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { MS_EXCEPTION_IF_NULL(anf_node); MS_LOG(INFO) << "LabelSwitchKernel init"; - if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, anf_node)) { + auto cnode = anf_node->cast(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; } auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc index 41e0991065..221bc9695d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc @@ -347,9 +347,13 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, auto nodes = TopoSort(kernel_graph->get_return()); for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::IsRealCNodeKernel(node) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, node)) { - auto fusion_id = AnfAlgo::GetNodeAttr(node, kOpAttrFusionId); - (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { + auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); + (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); } } } diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc index a39918ecee..f747321721 100644 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc +++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/getnext_memcpy_elimination.cc @@ -39,7 +39,7 @@ const AnfNodePtr GetnextMemcpyElimination::Process(const FuncGraphPtr &graph, co } // 1. memcpy has attr kAttrLabelForInsertStreamActive - if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, node)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, memcpy_cnode)) { MS_LOG(DEBUG) << "node has no label_for_insert_stream_active attr"; return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc index fb8b19047c..01a3f789e7 100644 --- a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc +++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.cc @@ -63,13 +63,14 @@ const AnfNodePtr InsertMemcpyAsyncForGetNext::Process(const FuncGraphPtr &func_g return nullptr; } - if (AnfAlgo::HasNodeAttr(kAttrVisited, node)) { + auto cnode = node->cast(); + if (AnfAlgo::HasNodeAttr(kAttrVisited, cnode)) { MS_LOG(DEBUG) << "Node op_name[" << kGetNextOpName << "] has visited."; return nullptr; } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cnode); - return InsertMemcpyAsyncForGetNextOutputs(func_graph, node); + return InsertMemcpyAsyncForGetNextOutputs(func_graph, cnode); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc index 393546278a..2b45fc6579 100644 --- a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc +++ b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc @@ -29,7 +29,14 @@ int32_t FusionIdAllocator::AllocateFusionId() { return fusion_id; } -bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { return AnfAlgo::HasNodeAttr(kAttrFusionId, node); } +bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + return AnfAlgo::HasNodeAttr(kAttrFusionId, cnode); +} int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) { if (HasFusionIdAttr(node)) { diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 2853f6760f..f141ebce6b 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -246,12 +246,8 @@ void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr primitive->EraseAttr(key); } -bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const AnfNodePtr &node) { +bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - MS_LOG(WARNING) << "Only cnode has attr, but this anf is " << node->DebugString(); - return false; - } auto primitive = AnfAlgo::GetCNodePrimitive(node); MS_EXCEPTION_IF_NULL(primitive); return primitive->HasAttr(key); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 78ebf31210..76b57482dd 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -80,7 +80,7 @@ class AnfRuntimeAlgorithm { // set all attrs from 'from' node to 'to' node static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to); // check whether a cnode has the specified attr. - static bool HasNodeAttr(const std::string &key, const AnfNodePtr &node); + static bool HasNodeAttr(const std::string &key, const CNodePtr &node); // delete attr of anf node static void EraseNodeAttr(const std::string &key, AnfNodePtr node); // get the num of input real_kernel(which can be build and run in device)