From 560aa11b5f81b0d75eba87c08ae3b4610f4aa57f Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Fri, 18 Dec 2020 12:57:25 +0800 Subject: [PATCH] fix ps cache check cnode --- .../backend/session/anf_runtime_algorithm.cc | 5 +++-- .../backend/session/anf_runtime_algorithm.h | 2 +- .../ccsrc/runtime/device/kernel_runtime.cc | 18 +++++++++--------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 0a79c01f0f..52fa279dc2 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -476,7 +476,8 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i return format; } -KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { +KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, + bool visit_nop_node) { MS_EXCEPTION_IF_NULL(anf_node); if (!anf_node->isa()) { MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." @@ -484,7 +485,7 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod } auto input_node = AnfAlgo::GetInputNode(anf_node->cast(), input_idx); MS_EXCEPTION_IF_NULL(input_node); - return VisitKernelWithReturnType(input_node, 0); + return VisitKernelWithReturnType(input_node, 0, visit_nop_node); } std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 4e24af27cc..a6a9c78ebe 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -112,7 +112,7 @@ class AnfRuntimeAlgorithm { // get input format select of anf node static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); // get prev node output width output index - static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); + static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false); // get output format from prev node,input_index is the input index of current node related to prev node static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); // get reshape_type of from the output of input node. diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index a9a13d11c5..35f6cdd91c 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { continue; } - auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); - auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); + auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); + auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true); MS_EXCEPTION_IF_NULL(input_param.first); MS_EXCEPTION_IF_NULL(input_index.first); auto param_name = input_param.first->fullname_with_scope(); @@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, continue; } auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); - while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { - input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); + while (input_index.first->isa() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { + input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); MS_EXCEPTION_IF_NULL(input_index.first); } - if ((!input_index.first->isa()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { + if (input_index.first->isa() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() @@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { continue; } - auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); - auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); + auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true); + auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true); MS_EXCEPTION_IF_NULL(input_param.first); MS_EXCEPTION_IF_NULL(input_index.first); auto param_name = input_param.first->fullname_with_scope(); - while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { - input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); + while (input_index.first->isa() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) { + input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true); MS_EXCEPTION_IF_NULL(input_index.first); } if (input_index.first == first_cache_input_index) {