!10161 fix ps cache check cnode

From: @limingqi107
Reviewed-by: @cristoval,@kisnwang
Signed-off-by: @cristoval
pull/10161/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 65c439335b

@ -476,7 +476,8 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
return format;
}
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx,
bool visit_nop_node) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."
@ -484,7 +485,7 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
}
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
MS_EXCEPTION_IF_NULL(input_node);
return VisitKernelWithReturnType(input_node, 0);
return VisitKernelWithReturnType(input_node, 0, visit_nop_node);
}
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {

@ -112,7 +112,7 @@ class AnfRuntimeAlgorithm {
// get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx, bool visit_nop_node = false);
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get reshape_type of from the output of input node.

@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
continue;
}
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
continue;
}
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
MS_EXCEPTION_IF_NULL(input_param.first);
MS_EXCEPTION_IF_NULL(input_index.first);
auto param_name = input_param.first->fullname_with_scope();
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
MS_EXCEPTION_IF_NULL(input_index.first);
}
if (input_index.first == first_cache_input_index) {

Loading…
Cancel
Save