!10161 fix ps cache check cnode

From: @limingqi107
Reviewed-by: @cristoval,@kisnwang
Signed-off-by: @cristoval
pull/10161/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 65c439335b

@ -476,7 +476,8 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
return format; return format;
} }
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) { KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
bool visit_nop_node) {
MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) { if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode." MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
@ -484,7 +485,7 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
} }
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx); auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
return VisitKernelWithReturnType(input_node, 0); return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
} }
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {

@ -112,7 +112,7 @@ class AnfRuntimeAlgorithm {
// get input format select of anf node // get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx); static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index // get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx); static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false);
// get output format from prev node,input_index is the input index of current node related to prev node // get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type of from the output of input node. // get reshape_type of from the output of input node.

@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue; continue;
} }
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first); MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first); MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope(); auto param_name = input_param.first->fullname_with_scope();
@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
continue; continue;
} }
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first); MS_EXCEPTION_IF_NULL(input_index.first);
} }
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue; continue;
} }
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first); MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first); MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope(); auto param_name = input_param.first->fullname_with_scope();
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first); MS_EXCEPTION_IF_NULL(input_index.first);
} }
if (input_index.first == first_cache_input_index) { if (input_index.first == first_cache_input_index) {

Loading…
Cancel
Save