|
|
|
@ -1044,8 +1044,8 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
|
|
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
|
|
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
|
|
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_param.first);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first);
|
|
|
|
|
auto param_name = input_param.first->fullname_with_scope();
|
|
|
|
@ -1053,11 +1053,11 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
|
|
|
|
|
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
|
|
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
|
|
|
|
|
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
|
|
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first);
|
|
|
|
|
}
|
|
|
|
|
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
|
|
|
|
|
if (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
|
|
|
|
|
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
|
|
|
|
|
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
|
|
|
|
|
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
|
|
|
|
@ -1085,13 +1085,13 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
|
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0);
|
|
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1);
|
|
|
|
|
auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
|
|
|
|
|
auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_param.first);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first);
|
|
|
|
|
auto param_name = input_param.first->fullname_with_scope();
|
|
|
|
|
while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) {
|
|
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second);
|
|
|
|
|
while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == "Cast")) {
|
|
|
|
|
input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second, true);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first);
|
|
|
|
|
}
|
|
|
|
|
if (input_index.first == first_cache_input_index) {
|
|
|
|
|