From 98d6898438e5673b63e41affaa0ec0a835f2b81c Mon Sep 17 00:00:00 2001 From: chujinjin Date: Tue, 15 Dec 2020 20:01:11 +0800 Subject: [PATCH] fix embeddinglookup error on cpu pynative mode --- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index d9f476f0ca..390f9cdf14 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -422,11 +422,14 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; reg_exist = false; } + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { - reg_exist = false; + if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kCPUDevice) { + reg_exist = false; + } } if (op_run_info->op_name == prim::kPrimGatherD->name()) { - auto ms_context = MsContext::GetInstance(); // Gather op needs converting const input to attr on GPU device if (ms_context->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { reg_exist = false;